未验证 提交 bd1d6d3b 编写于 作者: J Jacek Czaja 提交者: GitHub

extends oneDNN caching keys so caching objects are unique to executor/predictor (#28758)

上级 3d0ff8ee
......@@ -557,6 +557,7 @@ void Executor::EnableMKLDNN(const ProgramDesc& program) {
}
}
}
platform::AttachPointerHashToMKLDNNKey(this, place_);
#else
LOG(WARNING)
<< "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option";
......
......@@ -44,6 +44,9 @@ void NaiveExecutor::Prepare(Scope *scope, const ProgramDesc &program_desc,
}
void NaiveExecutor::Run() {
#ifdef PADDLE_WITH_MKLDNN
platform::AttachPointerHashToMKLDNNKey(this, place_);
#endif
for (auto &op : ops_) {
VLOG(4) << std::this_thread::get_id() << " run "
<< op->DebugStringEx(scope_) << " on scope " << scope_;
......
......@@ -160,7 +160,7 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::string key = platform::CreateKey(
paddle::framework::vectorize<int>(multi_input[0]->dims()),
multi_input.size(), ctx.OutputName("Out"), dt,
platform::ThreadIDasStr());
platform::ThreadIDasStr(), dev_ctx.GetKeySuffix());
const std::string key_prim = key + "@concat_p";
const std::string key_concat_pd = key + "@concat_pd";
......
......@@ -361,7 +361,8 @@ class FCPrimitiveFactory {
void CacheWeightsAndBias(const MKLDNNDeviceContext& dev_ctx,
const ExecutionContext& ctx) {
const std::string key = platform::CreateKey(platform::ThreadIDasStr());
const std::string key =
platform::CreateKey(platform::ThreadIDasStr(), dev_ctx.GetKeySuffix());
const std::string weights_key = key + ctx.InputName("W");
const std::string bias_key = key + ctx.InputName("Bias");
dev_ctx.SetBlob(weights_key, weights_);
......@@ -532,8 +533,9 @@ static void ExecuteFc(const ExecutionContext& ctx, const LoDTensor* input,
bool fuse_relu, bool force_fp32_output) {
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const std::string prim_key = platform::CreateKey(
platform::ThreadIDasStr(), input->format(), input->dims()[0],
framework::vectorize<int>(w->dims()), ctx.OutputName("Out"));
platform::ThreadIDasStr(), dev_ctx.GetKeySuffix(), input->format(),
input->dims()[0], framework::vectorize<int>(w->dims()),
ctx.OutputName("Out"));
constexpr bool is_int8 =
std::is_same<T_in, int8_t>::value || std::is_same<T_in, uint8_t>::value;
bool is_bfloat16 = std::is_same<T_in, paddle::platform::bfloat16>::value;
......
......@@ -337,8 +337,8 @@ static std::shared_ptr<MatMulFactory<XT, YT, OT>> GetPrimitiveFactory(
const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto batch_size = ctx.Input<Tensor>("X")->dims()[0];
const std::string key =
platform::CreateKey(platform::ThreadIDasStr(), batch_size, out_name);
const std::string key = platform::CreateKey(
platform::ThreadIDasStr(), dev_ctx.GetKeySuffix(), batch_size, out_name);
auto factory =
std::static_pointer_cast<MatMulFactory<XT, YT, OT>>(dev_ctx.GetBlob(key));
......
......@@ -535,6 +535,10 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
// Remove all entries from the blob map
void ResetBlobMap();
// Replace the stored key suffix with a copy of `suffix`.
// Key-building code reads it back through GetKeySuffix() and appends it to
// the keys it creates.
void SetKeySuffix(const std::string& suffix) {
  key_suffix_.assign(suffix);
}
// Read-only access to the suffix last stored with SetKeySuffix().
// The returned reference refers to the member itself, not to a copy.
const std::string& GetKeySuffix() const {
  return key_suffix_;
}
// Prevent next ResetBlobMap()
void BlockNextCacheClearing();
......@@ -556,6 +560,7 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
std::shared_ptr<BlobMap> p_blobmap_;
std::shared_ptr<std::mutex> p_mutex_;
bool block_next_cache_clearing_ = false;
std::string key_suffix_; // Key identifying current Executor
};
#endif
......
......@@ -433,6 +433,23 @@ inline void AppendKey(std::string* key, const std::vector<T>& dims) {
}
}
// Reduce a pointer value to a small number usable as a key fragment.
// The result is the pointer value modulo 1000, i.e. its three least
// significant decimal digits (range 0..999). Different pointers can therefore
// map to the same value.
inline unsigned int HashPointer(uintptr_t ptr) {
  constexpr uintptr_t kBuckets = 1000;
  const uintptr_t remainder = ptr % kBuckets;
  // The remainder is below 1000, so the conversion cannot lose information.
  return static_cast<unsigned int>(remainder);
}
// Register a pointer-derived key suffix in the device context of `place`.
//
// If `place` is a CPU place, the context obtained from the global
// DeviceContextPool for that place is treated as an MKLDNNDeviceContext and
// its key suffix is set to "E" followed by the decimal value of
// HashPointer(ptr). For any other place the call does nothing.
//
// ptr   - address the suffix is derived from; it is only converted to an
//         integer here, never dereferenced.
// place - used both for the CPU check and for the pool lookup.
inline void AttachPointerHashToMKLDNNKey(void* ptr,
                                         const platform::Place& place) {
  if (!platform::is_cpu_place(place)) {
    return;
  }
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  // Named cast instead of a C-style cast, so the conversion is explicit and
  // greppable.
  // NOTE(review): assumes pool.Get() returns a pointer to a base class of
  // MKLDNNDeviceContext and that the CPU context really is of that type -
  // confirm against device_context.h.
  auto* dev_ctx = static_cast<platform::MKLDNNDeviceContext*>(pool.Get(place));
  const unsigned int ptr_hash =
      platform::HashPointer(reinterpret_cast<uintptr_t>(ptr));
  dev_ctx->SetKeySuffix("E" + std::to_string(ptr_hash));
}
template <typename... ArgTypes>
inline std::string CreateKey(ArgTypes&&... args) {
std::string key;
......
......@@ -51,6 +51,7 @@ class MKLDNNHandlerT {
} else {
key_ = key_common_ + "-t:" + ThreadIDasStr();
}
key_ += dev_ctx.GetKeySuffix();
}
std::shared_ptr<TForward> AcquireForwardPrimitive() {
......@@ -316,6 +317,7 @@ class MKLDNNHandler {
} else {
key_ = key_common_ + "-t:" + ThreadIDasStr();
}
key_ += dev_ctx.GetKeySuffix();
}
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册