diff --git a/paddle/fluid/framework/data_layout_transform.cc b/paddle/fluid/framework/data_layout_transform.cc index 8563b5b6d3695878e4f65c131cff600d08451e4c..cd6839fe19e232b6319c55b19dc54f04c9102a9b 100644 --- a/paddle/fluid/framework/data_layout_transform.cc +++ b/paddle/fluid/framework/data_layout_transform.cc @@ -181,8 +181,8 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout, if (in_format != out_format) { void* in_data = GetDataFromTensor(in, in_type); - const std::string key = - platform::CreateKey(in_tz, in_format, out_format, in_type); + std::string key = + platform::CreateKey(*dev_ctx, in_tz, in_format, out_format, in_type); platform::ReorderMKLDNNHandler handler(in_tz, in.type(), in_type, *dev_ctx, cpu_engine, key); diff --git a/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc b/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc index e51d94e4b1e05a9b394e96fd2c0e561b46453793..1eed49de784089a078b462e7ab0e47a456df3c1b 100644 --- a/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc +++ b/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc @@ -39,20 +39,15 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT { const std::string& unique_name) : platform::MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - CreateKey(unique_name, MKLDNNGetDataType(), Ti)), + CreateKey(dev_ctx, unique_name, MKLDNNGetDataType(), Ti)), N(N), Ti(Ti), IC(IC), OC(OC) { // Create memory key without Ti because weights, bias and h0 memories // do not depend on Ti size but primitive and input/output memory do - if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() != - platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) { - memory_key_ = CreateKey(unique_name, MKLDNNGetDataType()); - } else { - memory_key_ = CreateKey(unique_name, MKLDNNGetDataType(), "-t:", - platform::ThreadIDasStr()); - } + memory_key_ = platform::ExtendKeyWithThreadInfoIfNeeded( + dev_ctx, CreateKey(dev_ctx, unique_name, MKLDNNGetDataType())); // Is it int8 kernel const bool is_INT8 = std::is_same::value; diff --git a/paddle/fluid/operators/fused/mkldnn/multi_gru_mkldnn_op.cc b/paddle/fluid/operators/fused/mkldnn/multi_gru_mkldnn_op.cc index b7fd40f78ff9d3fb829c1a0d5c2cc91a62a8455c..11711bab81735efd0494d454c33cf5aa0f0274a3 100644 --- a/paddle/fluid/operators/fused/mkldnn/multi_gru_mkldnn_op.cc +++ b/paddle/fluid/operators/fused/mkldnn/multi_gru_mkldnn_op.cc @@ -109,13 +109,8 @@ class MultiGRUHandler { const std::string unique_name = ctx.OutputName("Hidden"); // Create memory key without Ti because weights, bias and h0 memories // do not depend on Ti size but primitive and input/output memory do - if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() != - platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) { - memory_key_ = CreateKey(unique_name, MKLDNNGetDataType()); - } else { - memory_key_ = CreateKey(unique_name, MKLDNNGetDataType(), "-t:", - platform::ThreadIDasStr()); - } + memory_key_ = platform::ExtendKeyWithThreadInfoIfNeeded( + dev_ctx, CreateKey(dev_ctx, unique_name, MKLDNNGetDataType())); key_ = memory_key_; key_.append("T").append(std::to_string(Ti_)); diff --git a/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc index 98f368aa7a90859121a06e42705aee6355182b27..622d6685dfa718b1220ac4afbf67982b5acce188 100644 --- a/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc @@ -48,7 +48,8 @@ class BatchNormMKLDNNHandler : platform::MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(framework::vectorize(x->dims()), unique_name)) { + platform::CreateKey(dev_ctx, framework::vectorize(x->dims()), + unique_name)) { if (!this->isCached()) { const float epsilon = ctx.Attr("epsilon"); const bool fuse_with_relu = ctx.Attr("fuse_with_relu"); @@ -89,7 +90,7 @@ class BatchNormMKLDNNHandler : platform::MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(dims, uniq_name)) { + platform::CreateKey(dev_ctx, dims, uniq_name)) { auto diff_dst_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType(), diff_fmt); auto src_md = diff --git a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc index 78862a5559dab78dbce2af48c7b10febed39287a..63aa2357beea074c9ab6a11230d50a9ea114863a 100644 --- a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc @@ -159,9 +159,10 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel { // If one of the multiple inputs of concat has an input size of 0, the // actual size of the multi_input will change std::string key = platform::CreateKey( - paddle::framework::vectorize(multi_input[0]->dims()), + dev_ctx, paddle::framework::vectorize(multi_input[0]->dims()), multi_input.size(), ctx.OutputName("Out"), dt, - platform::ThreadIDasStr(), dev_ctx.GetKeySuffix()); + platform::ThreadIDasStr()); + key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); const std::string key_prim = key + "@concat_p"; const std::string key_concat_pd = key + "@concat_pd"; diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc index 99175a73e288ed4412db28faf3efc370849be650..1fc0f14e5ddd9d21892bb1520579f1726ffc207b 100644 --- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc @@ -95,7 +95,7 @@ class ConvMKLDNNHandlerT const std::string& unique_name) : platform::MKLDNNHandlerT( dev_ctx, mkldnn_engine, cpu_place, - platform::CreateKey(framework::vectorize(input->dims()), + platform::CreateKey(dev_ctx, framework::vectorize(input->dims()), unique_name)) { if (!this->isCached()) { PADDLE_ENFORCE_EQ( @@ -521,8 +521,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { mkldnn::memory::data_type src_dt = paddle::framework::ToMKLDNNDataType(input->type()); - std::string key = platform::CreateKey( - src_tz, src_dt, ctx.InputName("Input") + ctx.InputName("Filter")); + std::string key = + platform::CreateKey(dev_ctx, src_tz, src_dt, + ctx.InputName("Input") + ctx.InputName("Filter")); const std::string key_conv_pd = key + "@conv_pd"; bool need_s8_to_u8 = false; @@ -537,21 +538,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { // This is workaround for hacky implementation // of conv int8 mkl-dnn. Once conv fp32 and conv int8 // are merged/unified, this will disappear - std::string key_tid = ""; - if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() == - platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) { - key_tid = "-t:" + platform::ThreadIDasStr(); - } - - auto prim_key = key + key_tid + "@conv_p"; - auto dst_key = key + key_tid + "@dst_mem_p"; - auto src_key = key + key_tid + "@src_mem_p"; - auto weights_key = key + key_tid + "@weights_mem_p"; - auto bias_key = key + key_tid + "@bias_mem_p"; - auto user_src_key = key + key_tid + "@user_src_mem_p"; - auto user_residual_key = key + key_tid + "@user_residual_data_mem_p"; - auto src_reorder_key = key + key_tid + "@src_mem_preorder_p"; - auto residual_reorder_key = key + key_tid + "@residual_data_mem_preorder_p"; + auto key_tid = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); + + auto prim_key = key_tid + "@conv_p"; + auto dst_key = key_tid + "@dst_mem_p"; + auto src_key = key_tid + "@src_mem_p"; + auto weights_key = key_tid + "@weights_mem_p"; + auto bias_key = key_tid + "@bias_mem_p"; + auto user_src_key = key_tid + "@user_src_mem_p"; + auto user_residual_key = key_tid + "@user_residual_data_mem_p"; + auto src_reorder_key = key_tid + "@src_mem_preorder_p"; + auto residual_reorder_key = key_tid + "@residual_data_mem_preorder_p"; conv_p = std::static_pointer_cast( dev_ctx.GetBlob(prim_key)); @@ -964,10 +961,11 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { // Get an unique name from "argument" name of "input" and "Filter" variable // as well as attributes of primitive to be created // This name will be used as key when saving info into device context - const std::string key = platform::CreateKey( - src_tz, ctx.InputName("Input") + ctx.InputName("Filter")); + std::string key = platform::CreateKey( + dev_ctx, src_tz, ctx.InputName("Input") + ctx.InputName("Filter")); const std::string key_conv_pd = key + "@fwd_pd"; + key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); std::vector pipeline; // Create user memory descriptors @@ -1082,8 +1080,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { mkldnn::memory::format_tag out_format = weights_tz.size() == 6 ? mkldnn::memory::format_tag::goidhw : mkldnn::memory::format_tag::goihw; - const std::string key = - platform::CreateKey(weights_tz, filter_fmt, out_format, in_type); + std::string key = platform::CreateKey(dev_ctx, weights_tz, filter_fmt, + out_format, in_type); + key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); platform::ReorderMKLDNNHandler handler(weights_tz, filter_grad->type(), in_type, dev_ctx, mkldnn_engine, diff --git a/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc index e9f32e7ac25d8ea8f90f8beaf4cdc5b6ff086cf0..1eb90451a6952944afc3faee335e2a010fb3c2de 100644 --- a/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc @@ -172,9 +172,8 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel { auto dst_tz = paddle::framework::vectorize(output->dims()); // Get unique name for storing MKLDNN primitives - const std::string key = - platform::CreateKey(src_tz, ctx.OutputName("Output")); + platform::CreateKey(dev_ctx, src_tz, ctx.OutputName("Output")); std::vector pipeline; diff --git a/paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc index e036fd9aba04b214c22f2f179de5ba5eb5dd277d..8d41b750972352df7c957c6295cab972f3031a2a 100644 --- a/paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc @@ -67,8 +67,11 @@ class DeQuantOpKernel : public framework::OpKernel { mkldnn::memory::data_type src_dt = paddle::framework::ToMKLDNNDataType(input->type()); MKLDNNMemoryFormat src_fmt = input->format(); - std::string key = platform::CreateKey(platform::ThreadIDasStr(), src_dt, - src_tz, ctx.OutputName("Output")); + + std::string key = + platform::CreateKey(dev_ctx, src_dt, src_tz, ctx.OutputName("Output")); + key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); + const std::string key_prim = key + "@r"; const std::string key_src_mem = key + "@s"; const std::string key_dst_mem = key + "@d"; diff --git a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc index 9e0bdeee6b38b127e96fdbe0015a228ac720750b..c817ef7269fef12da8a43e44187cdf9a959bc9a3 100644 --- a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc @@ -361,8 +361,9 @@ class FCPrimitiveFactory { void CacheWeightsAndBias(const MKLDNNDeviceContext& dev_ctx, const ExecutionContext& ctx) { - const std::string key = - platform::CreateKey(platform::ThreadIDasStr(), dev_ctx.GetKeySuffix()); + std::string key = platform::CreateKey(dev_ctx); + key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); + const std::string weights_key = key + ctx.InputName("W"); const std::string bias_key = key + ctx.InputName("Bias"); dev_ctx.SetBlob(weights_key, weights_); @@ -532,10 +533,11 @@ static void ExecuteFc(const ExecutionContext& ctx, const LoDTensor* input, const Tensor* w, const Tensor* bias, LoDTensor* output, bool fuse_relu, bool force_fp32_output) { auto& dev_ctx = ctx.template device_context(); - const std::string prim_key = platform::CreateKey( - platform::ThreadIDasStr(), dev_ctx.GetKeySuffix(), input->format(), - input->dims()[0], framework::vectorize(w->dims()), - ctx.OutputName("Out")); + std::string prim_key = platform::CreateKey( + dev_ctx, input->format(), input->dims()[0], + framework::vectorize(w->dims()), ctx.OutputName("Out")); + prim_key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, prim_key); + constexpr bool is_int8 = std::is_same::value || std::is_same::value; bool is_bfloat16 = std::is_same::value; diff --git a/paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc index 22261e948aa7b6799841afd72d8c1a6188382309..65dcb328f20839d4dc9f37e1b7175a4a0245e99e 100644 --- a/paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc @@ -30,7 +30,7 @@ class LayerNormMKLDNNHandler const std::string& uniq_name) : platform::MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(dims, uniq_name)) { + platform::CreateKey(dev_ctx, dims, uniq_name)) { if (!this->isCached()) { auto md = dnnl::memory::desc(dims, platform::MKLDNNGetDataType(), fmt); if (!is_test) { diff --git a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc index 9f1fcf5bd0fbeb49a114c37c6c39bd1612a3359e..fddc4b4b2e5596c2f5fa6167869deb7d7cacf600 100644 --- a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc @@ -336,9 +336,8 @@ static std::shared_ptr> GetPrimitiveFactory( const auto& out_name = ctx.OutputName("Out"); const auto& dev_ctx = ctx.template device_context(); const auto batch_size = ctx.Input("X")->dims()[0]; - - const std::string key = platform::CreateKey( - platform::ThreadIDasStr(), dev_ctx.GetKeySuffix(), batch_size, out_name); + std::string key = platform::CreateKey(dev_ctx, batch_size, out_name); + key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); auto factory = std::static_pointer_cast>(dev_ctx.GetBlob(key)); diff --git a/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc index 5abb7bf406a979bd9aedf4bf7e7d713b82dbba69..173a7cd867db40623f367f4ac9962f7860f0cb2b 100644 --- a/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc @@ -296,9 +296,11 @@ std::shared_ptr> GetPrimitiveFactory( const MKLDNNDeviceContext &dev_ctx, const ExecutionContext &ctx, const Tensor *input_x, const Tensor *input_y, const mkldnn::engine &mkldnn_engine) { - const std::string key = platform::CreateKey( - input_x->type(), framework::vectorize(input_x->dims()), input_y->type(), - framework::vectorize(input_y->dims()), ctx.OutputName("Out")); + std::string key = platform::CreateKey( + dev_ctx, input_x->type(), framework::vectorize(input_x->dims()), + input_y->type(), framework::vectorize(input_y->dims()), + ctx.OutputName("Out")); + key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); auto prim_creator = std::static_pointer_cast>( dev_ctx.GetBlob(key)); diff --git a/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc index 4e689f5bccf4b4f3925eec18710faa87fe47498a..9488a1a4405a46e6666d537dd9e441125fbfbaa9 100644 --- a/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc @@ -140,7 +140,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel { // Get an unique name from "argument" name of "Out" variable // This name will be used as key when referring info from device context const std::string key = platform::CreateKey( - diff_src_tz, pooling_type, ksize, strides, paddings, + dev_ctx, diff_src_tz, pooling_type, ksize, strides, paddings, memory::data_type::f32, in_x->format(), ctx.InputName("Out")); platform::PoolingMKLDNNHandler handler( diff --git a/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc index e5dedd403f39f18ae7002296d589947b887d1430..0dbada58e4c6daed80e68a7f0f58ff63b2d1b1dc 100644 --- a/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc @@ -64,9 +64,11 @@ class QuantOpKernel : public framework::OpKernel { bool is_negative_input = ctx.Attr("is_negative_input"); bool bfloat16 = ctx.Attr("bfloat16"); - std::string key = platform::CreateKey( - platform::ThreadIDasStr(), src_tz, scale_data, scale_shift, - is_negative_input, ctx.OutputName("Output")); + std::string key = + platform::CreateKey(dev_ctx, src_tz, scale_data, scale_shift, + is_negative_input, ctx.OutputName("Output")); + key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); + const std::string key_prim = key + "@r"; const std::string key_src_mem = key + "@s"; const std::string key_dst_mem = key + "@d"; diff --git a/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc index 4666e5b74a5cc252b078a434c58f9ac6f3f2891a..7c906a4ddab62ea6f2e0fda0c48229ad220de858 100644 --- a/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc @@ -65,9 +65,9 @@ class ReQuantOpKernel : public framework::OpKernel { float reorder_scale = scale_out / scale_in; - std::string key = - platform::CreateKey(platform::ThreadIDasStr(), src_tz, scale_in, - scale_out, ctx.OutputName("Output")); + std::string key = platform::CreateKey(dev_ctx, src_tz, scale_in, scale_out, + ctx.OutputName("Output")); + key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); const std::string key_prim = key + "@r"; const std::string key_src_mem = key + "@s"; const std::string key_dst_mem = key + "@d"; diff --git a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc index 9d9e1e2d8ded5190177edbb3eec7dfef6f875abf..3eb2e7084a0b07d20380d49012e6fe20973f5335 100644 --- a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc @@ -53,8 +53,8 @@ class SoftmaxMKLDNNHandler mkldnn::softmax_backward>( dev_ctx, mkldnn_engine, cpu_place, // Softmax may be inplace then uniq_name is no longer unique - platform::CreateKey(framework::vectorize(input->dims()), axis, - uniq_name)) { + platform::CreateKey(dev_ctx, framework::vectorize(input->dims()), + axis, uniq_name)) { if (!this->isCached()) { PADDLE_ENFORCE_EQ( input->dims(), output->dims(), @@ -78,7 +78,7 @@ class SoftmaxMKLDNNHandler : platform::MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(dims, axis, uniq_name)) { + platform::CreateKey(dev_ctx, dims, axis, uniq_name)) { auto data_softmax_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType(), fmt); auto diff_softmax_md = diff --git a/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc index 4df7818072f0538305808cd14606ae45ea84238d..90519caa40f2b402f5348ac47a66cc4a6950f122 100644 --- a/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc @@ -54,7 +54,8 @@ class SumMKLDNNHandler : public platform::MKLDNNHandlerT { : platform::MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(framework::vectorize(z->dims()), uniq_name)), + platform::CreateKey(dev_ctx, framework::vectorize(z->dims()), + uniq_name)), num_inputs_(0) { for (size_t i = 0; i < in_vars.size(); i++) { srcs_suffix_.push_back(std::string("-") + std::to_string(i)); @@ -184,8 +185,9 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel { // For in-place execution which sum does not have we need to fake it // so from oneDNN dst memory we reorder data into input if (in_place) { - const std::string reorder_key = platform::CreateKey( - framework::vectorize(output->dims()), ctx.OutputName("Out") + "-I"); + const std::string reorder_key = + platform::CreateKey(dev_ctx, framework::vectorize(output->dims()), + ctx.OutputName("Out") + "-I"); auto& in_out = in_vars[0]->Get(); auto output_tz = framework::vectorize(output->dims()); diff --git a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc index 28cdd8413ab134224b72abd9a14fd1398784e056..feda5645b4cfa2bf580cc5bcefbc41d124dd3cc5 100644 --- a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc @@ -48,7 +48,8 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel { auto nchw_tz = paddle::framework::vectorize(input->dims()); - const std::string key = platform::CreateKey(nchw_tz, ctx.OutputName("Out")); + const std::string key = + platform::CreateKey(dev_ctx, nchw_tz, ctx.OutputName("Out")); platform::TransposeMKLDNNHandler handler(nchw_tz, axis, dev_ctx, mkldnn_engine, key); @@ -103,7 +104,7 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel { auto nchw_tz = paddle::framework::vectorize(out_grad->dims()); const std::string key = platform::CreateKey( - nchw_tz, ctx.OutputName(framework::GradVarName("X"))); + dev_ctx, nchw_tz, ctx.OutputName(framework::GradVarName("X"))); platform::TransposeMKLDNNHandler handler(nchw_tz, reversed_axis, dev_ctx, mkldnn_engine, key); diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index de4c4a8363552d4ddc61de31052c997fc76a39c8..56438a95f2a8907bfb13bd192a9eb30e5082b4be 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -534,6 +534,10 @@ class MKLDNNDeviceContext : public CPUDeviceContext { void SetKeySuffix(const std::string& suffix) { key_suffix_ = suffix; } const std::string& GetKeySuffix(void) const { return key_suffix_; } + // Disable adding thread ID to the key + void DisableThreadInfoInKey(void) { key_attach_thread_id_ = false; }; + bool IsThreadIdUsedInKey(void) const { return key_attach_thread_id_; }; + // Prevent next ResetBlobMap() void BlockNextCacheClearing(); @@ -556,6 +560,7 @@ class MKLDNNDeviceContext : public CPUDeviceContext { std::shared_ptr p_mutex_; bool block_next_cache_clearing_ = false; std::string key_suffix_; // Key identifying current Executor + bool key_attach_thread_id_ = true; }; #endif diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index 797ff42f3c201458fd02caa445a9f5336a3cdb19..59a95e34c5478ef0d2f2e7fd5a97a6198d167790 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -428,11 +428,6 @@ inline void AppendKey(std::string* key, const std::vector& dims) { } } -inline unsigned int HashPointer(uintptr_t ptr) { - // Get four less meaningful digits in decimal numerals - return ptr % 1000; -} - // If MKLDNN build and CPU place then register suffix in DeviceContext inline void AttachPointerHashToMKLDNNKey(void* ptr, const platform::Place& place) { @@ -440,20 +435,34 @@ inline void AttachPointerHashToMKLDNNKey(void* ptr, platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::MKLDNNDeviceContext* dev_ctx = (platform::MKLDNNDeviceContext*)pool.Get(place); - dev_ctx->SetKeySuffix("E" + std::to_string(platform::HashPointer( - reinterpret_cast(ptr)))); + dev_ctx->SetKeySuffix("E" + + std::to_string(reinterpret_cast(ptr))); + // When NaiveExecutor/Executor is used no info on thread id is needed in a + // key + dev_ctx->DisableThreadInfoInKey(); } } template -inline std::string CreateKey(ArgTypes&&... args) { +inline std::string CreateKey(const platform::MKLDNNDeviceContext& dev_ctx, + ArgTypes&&... args) { std::string key; key.reserve(64); using expand_type = int[]; expand_type{0, (AppendKey(&key, std::forward(args)), 0)...}; + key += dev_ctx.GetKeySuffix(); return key; } +inline std::string ExtendKeyWithThreadInfoIfNeeded( + const platform::MKLDNNDeviceContext& dev_ctx, const std::string& key) { + return ((dev_ctx.IsThreadIdUsedInKey() == true) && + (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() == + platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default)) + ? key + "-t:" + ThreadIDasStr() + : key; +} + inline std::vector> ToMkldnnPadding( const std::vector& paddings) { if (paddings.size() == 6) { diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 96f8fd29c7964c0a21156413f27dcafea9a1eaea..e884d879ffa23b455e0e7289aa9cacfde0183b0a 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -43,16 +43,10 @@ class MKLDNNHandlerT { engine_(engine), place_(cpu_place), key_common_(base_key), + key_(platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, base_key)), fwd_pd_(nullptr), bwd_pd_(nullptr) { platform::MKLDNNDeviceContext::tls().log_lib_version(); - if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() != - platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) { - key_ = key_common_; - } else { - key_ = key_common_ + "-t:" + ThreadIDasStr(); - } - key_ += dev_ctx.GetKeySuffix(); } std::shared_ptr AcquireForwardPrimitive() { @@ -300,8 +294,8 @@ class MKLDNNHandlerT { const MKLDNNDeviceContext& dev_ctx_; mkldnn::engine engine_; platform::Place place_; - std::string key_; std::string key_common_; + std::string key_; std::shared_ptr fwd_pd_; std::shared_ptr bwd_pd_; }; @@ -311,15 +305,11 @@ class MKLDNNHandler { public: MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine, const std::string& base_key) - : dev_ctx_(dev_ctx), engine_(engine), key_common_(base_key) { + : dev_ctx_(dev_ctx), + engine_(engine), + key_common_(base_key), + key_(platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, base_key)) { platform::MKLDNNDeviceContext::tls().log_lib_version(); - if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() != - platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) { - key_ = key_common_; - } else { - key_ = key_common_ + "-t:" + ThreadIDasStr(); - } - key_ += dev_ctx.GetKeySuffix(); } std::shared_ptr AcquireSrcMemory( @@ -497,8 +487,8 @@ class MKLDNNHandler { protected: const MKLDNNDeviceContext& dev_ctx_; mkldnn::engine engine_; - std::string key_; std::string key_common_; + std::string key_; }; template @@ -513,7 +503,7 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT { : platform::MKLDNNHandlerT( dev_ctx, engine, cpu_place, platform::CreateKey( - framework::vectorize(x->dims()), + dev_ctx, framework::vectorize(x->dims()), uniq_name + (algo == dnnl::algorithm::binary_mul ? "M" : ""))) { // bradcasting combined with in-place may require auto rankdiff = x->dims().size() - y->dims().size(); @@ -616,7 +606,7 @@ class ActivationMKLDNNHandler : platform::MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(dims, "a", algorithm, unique_name)) { + platform::CreateKey(dev_ctx, dims, "a", algorithm, unique_name)) { auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType(), fmt); this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training, @@ -634,7 +624,7 @@ class ActivationMKLDNNHandler : platform::MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(dims, "a", algorithm, unique_name)) { + platform::CreateKey(dev_ctx, dims, "a", algorithm, unique_name)) { auto diff_dst_md = platform::MKLDNNMemDesc( dims, platform::MKLDNNGetDataType(), diff_fmt); auto src_md = @@ -665,7 +655,7 @@ class LRNMKLDNNHandler : platform::MKLDNNHandlerT( dev_ctx, mkldnn_engine, cpu_place, - platform::CreateKey(framework::vectorize(input->dims()), + platform::CreateKey(dev_ctx, framework::vectorize(input->dims()), unique_name)) { if (!this->isCached()) { const int n = ctx.Attr("n"); @@ -701,7 +691,7 @@ class LRNMKLDNNHandler : platform::MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(dims, unique_name)) { + platform::CreateKey(dev_ctx, dims, unique_name)) { auto src_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType(), fmt); auto diff_md = @@ -741,7 +731,7 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(framework::vectorize(input->dims()), + platform::CreateKey(dev_ctx, framework::vectorize(input->dims()), framework::ToMKLDNNDataType(input->type()), unique_name)) { if (!this->isCached()) { @@ -850,7 +840,7 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(diff_src_dims, dt, unique_name)) { + platform::CreateKey(dev_ctx, diff_src_dims, dt, unique_name)) { auto diff_dst_md = mkldnn::memory::desc( diff_dst_dims, platform::MKLDNNGetDataType(), diff_dst_fmt); auto diff_src_md =