未验证 提交 bd1d6d3b 编写于 作者: J Jacek Czaja 提交者: GitHub

extends oneDNN caching keys so caching objects are unique to executor/predictor (#28758)

上级 3d0ff8ee
...@@ -557,6 +557,7 @@ void Executor::EnableMKLDNN(const ProgramDesc& program) { ...@@ -557,6 +557,7 @@ void Executor::EnableMKLDNN(const ProgramDesc& program) {
} }
} }
} }
platform::AttachPointerHashToMKLDNNKey(this, place_);
#else #else
LOG(WARNING) LOG(WARNING)
<< "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option"; << "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option";
......
...@@ -44,6 +44,9 @@ void NaiveExecutor::Prepare(Scope *scope, const ProgramDesc &program_desc, ...@@ -44,6 +44,9 @@ void NaiveExecutor::Prepare(Scope *scope, const ProgramDesc &program_desc,
} }
void NaiveExecutor::Run() { void NaiveExecutor::Run() {
#ifdef PADDLE_WITH_MKLDNN
platform::AttachPointerHashToMKLDNNKey(this, place_);
#endif
for (auto &op : ops_) { for (auto &op : ops_) {
VLOG(4) << std::this_thread::get_id() << " run " VLOG(4) << std::this_thread::get_id() << " run "
<< op->DebugStringEx(scope_) << " on scope " << scope_; << op->DebugStringEx(scope_) << " on scope " << scope_;
......
...@@ -160,7 +160,7 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -160,7 +160,7 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::string key = platform::CreateKey( std::string key = platform::CreateKey(
paddle::framework::vectorize<int>(multi_input[0]->dims()), paddle::framework::vectorize<int>(multi_input[0]->dims()),
multi_input.size(), ctx.OutputName("Out"), dt, multi_input.size(), ctx.OutputName("Out"), dt,
platform::ThreadIDasStr()); platform::ThreadIDasStr(), dev_ctx.GetKeySuffix());
const std::string key_prim = key + "@concat_p"; const std::string key_prim = key + "@concat_p";
const std::string key_concat_pd = key + "@concat_pd"; const std::string key_concat_pd = key + "@concat_pd";
......
...@@ -361,7 +361,8 @@ class FCPrimitiveFactory { ...@@ -361,7 +361,8 @@ class FCPrimitiveFactory {
void CacheWeightsAndBias(const MKLDNNDeviceContext& dev_ctx, void CacheWeightsAndBias(const MKLDNNDeviceContext& dev_ctx,
const ExecutionContext& ctx) { const ExecutionContext& ctx) {
const std::string key = platform::CreateKey(platform::ThreadIDasStr()); const std::string key =
platform::CreateKey(platform::ThreadIDasStr(), dev_ctx.GetKeySuffix());
const std::string weights_key = key + ctx.InputName("W"); const std::string weights_key = key + ctx.InputName("W");
const std::string bias_key = key + ctx.InputName("Bias"); const std::string bias_key = key + ctx.InputName("Bias");
dev_ctx.SetBlob(weights_key, weights_); dev_ctx.SetBlob(weights_key, weights_);
...@@ -532,8 +533,9 @@ static void ExecuteFc(const ExecutionContext& ctx, const LoDTensor* input, ...@@ -532,8 +533,9 @@ static void ExecuteFc(const ExecutionContext& ctx, const LoDTensor* input,
bool fuse_relu, bool force_fp32_output) { bool fuse_relu, bool force_fp32_output) {
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const std::string prim_key = platform::CreateKey( const std::string prim_key = platform::CreateKey(
platform::ThreadIDasStr(), input->format(), input->dims()[0], platform::ThreadIDasStr(), dev_ctx.GetKeySuffix(), input->format(),
framework::vectorize<int>(w->dims()), ctx.OutputName("Out")); input->dims()[0], framework::vectorize<int>(w->dims()),
ctx.OutputName("Out"));
constexpr bool is_int8 = constexpr bool is_int8 =
std::is_same<T_in, int8_t>::value || std::is_same<T_in, uint8_t>::value; std::is_same<T_in, int8_t>::value || std::is_same<T_in, uint8_t>::value;
bool is_bfloat16 = std::is_same<T_in, paddle::platform::bfloat16>::value; bool is_bfloat16 = std::is_same<T_in, paddle::platform::bfloat16>::value;
......
...@@ -337,8 +337,8 @@ static std::shared_ptr<MatMulFactory<XT, YT, OT>> GetPrimitiveFactory( ...@@ -337,8 +337,8 @@ static std::shared_ptr<MatMulFactory<XT, YT, OT>> GetPrimitiveFactory(
const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto batch_size = ctx.Input<Tensor>("X")->dims()[0]; const auto batch_size = ctx.Input<Tensor>("X")->dims()[0];
const std::string key = const std::string key = platform::CreateKey(
platform::CreateKey(platform::ThreadIDasStr(), batch_size, out_name); platform::ThreadIDasStr(), dev_ctx.GetKeySuffix(), batch_size, out_name);
auto factory = auto factory =
std::static_pointer_cast<MatMulFactory<XT, YT, OT>>(dev_ctx.GetBlob(key)); std::static_pointer_cast<MatMulFactory<XT, YT, OT>>(dev_ctx.GetBlob(key));
......
...@@ -535,6 +535,10 @@ class MKLDNNDeviceContext : public CPUDeviceContext { ...@@ -535,6 +535,10 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
// Remove all entries from the blob map // Remove all entries from the blob map
void ResetBlobMap(); void ResetBlobMap();
// Suffix appended to every oneDNN cache key built through this context.
// It identifies the owning executor/predictor so that cached oneDNN
// objects are not accidentally shared between different executors.
void SetKeySuffix(const std::string& suffix) { key_suffix_ = suffix; }
// Returns the per-executor cache-key suffix (default-constructed, i.e.
// empty, until SetKeySuffix() has been called).
const std::string& GetKeySuffix(void) const { return key_suffix_; }
// Prevent next ResetBlobMap() // Prevent next ResetBlobMap()
void BlockNextCacheClearing(); void BlockNextCacheClearing();
...@@ -556,6 +560,7 @@ class MKLDNNDeviceContext : public CPUDeviceContext { ...@@ -556,6 +560,7 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
std::shared_ptr<BlobMap> p_blobmap_; std::shared_ptr<BlobMap> p_blobmap_;
std::shared_ptr<std::mutex> p_mutex_; std::shared_ptr<std::mutex> p_mutex_;
bool block_next_cache_clearing_ = false; bool block_next_cache_clearing_ = false;
std::string key_suffix_; // Key identifying current Executor
}; };
#endif #endif
......
...@@ -433,6 +433,23 @@ inline void AppendKey(std::string* key, const std::vector<T>& dims) { ...@@ -433,6 +433,23 @@ inline void AppendKey(std::string* key, const std::vector<T>& dims) {
} }
} }
// Reduce a pointer value to its three least significant decimal digits
// (range 0-999), used as a short per-executor cache-key suffix.
// NOTE: this is intentionally lossy and not collision-free -- two distinct
// executors may map to the same suffix; it only lowers the collision odds.
inline unsigned int HashPointer(uintptr_t ptr) {
  // ptr % 1000 keeps exactly the last three decimal digits, not four as the
  // previous comment claimed.
  return ptr % 1000;
}
// If running on a CPU place (where oneDNN/MKL-DNN kernels execute), register
// a suffix derived from the executor/predictor pointer in the
// MKLDNNDeviceContext, so that cached oneDNN objects become unique per
// executor/predictor instance.
// @param ptr   Address of the owning Executor/NaiveExecutor; only its value
//              is used (hashed), it is never dereferenced.
// @param place Device placement; non-CPU places are ignored.
inline void AttachPointerHashToMKLDNNKey(void* ptr,
                                         const platform::Place& place) {
  if (platform::is_cpu_place(place)) {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    // pool.Get() returns the context for this place; in an MKLDNN-enabled
    // build the CPU context is an MKLDNNDeviceContext, so use a named
    // static_cast instead of the previous C-style cast.
    auto* dev_ctx =
        static_cast<platform::MKLDNNDeviceContext*>(pool.Get(place));
    // "E" prefix marks the suffix as an Executor identifier in cache keys.
    dev_ctx->SetKeySuffix("E" + std::to_string(platform::HashPointer(
                                    reinterpret_cast<uintptr_t>(ptr))));
  }
}
template <typename... ArgTypes> template <typename... ArgTypes>
inline std::string CreateKey(ArgTypes&&... args) { inline std::string CreateKey(ArgTypes&&... args) {
std::string key; std::string key;
......
...@@ -51,6 +51,7 @@ class MKLDNNHandlerT { ...@@ -51,6 +51,7 @@ class MKLDNNHandlerT {
} else { } else {
key_ = key_common_ + "-t:" + ThreadIDasStr(); key_ = key_common_ + "-t:" + ThreadIDasStr();
} }
key_ += dev_ctx.GetKeySuffix();
} }
std::shared_ptr<TForward> AcquireForwardPrimitive() { std::shared_ptr<TForward> AcquireForwardPrimitive() {
...@@ -316,6 +317,7 @@ class MKLDNNHandler { ...@@ -316,6 +317,7 @@ class MKLDNNHandler {
} else { } else {
key_ = key_common_ + "-t:" + ThreadIDasStr(); key_ = key_common_ + "-t:" + ThreadIDasStr();
} }
key_ += dev_ctx.GetKeySuffix();
} }
std::shared_ptr<mkldnn::memory> AcquireSrcMemory( std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册