diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc index e581e23a608c3c44066acc30b131eb62d4c30388..2058b868e29d36006ceb497eda89504a96a7bf9b 100644 --- a/paddle/fluid/operators/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/conv_mkldnn_op.cc @@ -560,17 +560,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { } - static std::unordered_map>> md_map; + static std::unordered_map> md_map; bool md_reuse = true; - std::vector> mds(8, nullptr); - std::vector> none_mds = {}; - //auto user_src_md_key = key + "@user_src_md"; - if (GetMdMap(md_map, key) == none_mds){ + auto user_src_md_key = key + "@user_src_md"; + if (GetMdMap(md_map, user_src_md_key) == nullptr){ md_reuse = false; //we suppose all mds are reused if the first md is in the map. - } else{ - mds = GetMdMap(md_map, key); } - //auto user_weights_md_key = key + "@user_weights_md"; + auto user_weights_md_key = key + "@user_weights_md"; std::shared_ptr user_src_md; std::shared_ptr user_weights_md; std::vector pipeline; @@ -582,16 +578,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { user_weights_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc( {weights_tz}, platform::MKLDNNGetDataType(), (g == 1) ? mkldnn::memory::format::oihw : mkldnn::memory::format::goihw))); - - mds[0] = user_src_md; - mds[1] = user_weights_md; - //SetMdMap(md_map, user_src_md_key, user_src_md); - //SetMdMap(md_map, user_weights_md_key, user_weights_md); + + SetMdMap(md_map, user_src_md_key, user_src_md); + SetMdMap(md_map, user_weights_md_key, user_weights_md); } else{ - user_src_md = mds[0]; - user_weights_md = mds[1]; - //user_src_md = GetMdMap(md_map, user_src_md_key); - //user_weights_md = GetMdMap(md_map, user_weights_md_key); + user_src_md = GetMdMap(md_map, user_src_md_key); + user_weights_md = GetMdMap(md_map, user_weights_md_key); } /* create memory descriptor for convolution without specified format @@ -605,10 +597,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { std::shared_ptr conv_pd; auto bias_tz = paddle::framework::vectorize2int(bias->dims()); - //auto src_md_key = key + "@src_md"; - //auto weights_md_key = key + "@weights_md_key"; - //auto dst_md_key = key + "@dst_md_key"; - //auto bias_md_key = key + "@bias_md_key"; + auto src_md_key = key + "@src_md"; + auto weights_md_key = key + "@weights_md_key"; + auto dst_md_key = key + "@dst_md_key"; + auto bias_md_key = key + "@bias_md_key"; std::shared_ptr src_md; std::shared_ptr weights_md; std::shared_ptr dst_md; @@ -629,19 +621,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { if(force_fp32_output) dst_dt = paddle::framework::ToMKLDNNDataType(std::type_index(typeid(float))); dst_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format))); - mds[2] = src_md; - mds[3] = weights_md; - mds[4] = dst_md; - //SetMdMap(md_map, src_md_key, src_md); - //SetMdMap(md_map, weights_md_key, weights_md); - //SetMdMap(md_map, dst_md_key, dst_md); + SetMdMap(md_map, src_md_key, src_md); + SetMdMap(md_map, weights_md_key, weights_md); + SetMdMap(md_map, dst_md_key, dst_md); } else{ - src_md = mds[2]; - weights_md = mds[3]; - dst_md = mds[4]; - //src_md = GetMdMap(md_map, src_md_key); - //weights_md = GetMdMap(md_map, weights_md_key); - //dst_md = GetMdMap(md_map, dst_md_key); + src_md = GetMdMap(md_map, src_md_key); + weights_md = GetMdMap(md_map, weights_md_key); + dst_md = GetMdMap(md_map, dst_md_key); } // create a conv primitive descriptor and save it for usage in backward @@ -650,11 +636,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { if(!md_reuse){ bias_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc( bias_tz, memory::data_type::s32, memory::format::x))); - mds[5] = bias_md; - //SetMdMap(md_map, bias_md_key, bias_md); + SetMdMap(md_map, bias_md_key, bias_md); } else{ - bias_md = mds[5]; - //bias_md = GetMdMap(md_map, bias_md_key); + bias_md = GetMdMap(md_map, bias_md_key); } conv_pd = ConvFwdPrimitiveDesc(*src_md, *weights_md, *bias_md, *dst_md, @@ -675,19 +659,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { weights_tz, platform::MKLDNNGetDataType(), chosen_memory_format))); dst_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc( dst_tz, platform::MKLDNNGetDataType(), chosen_memory_format))); - mds[2] = src_md; - mds[3] = weights_md; - mds[4] = dst_md; - //SetMdMap(md_map, src_md_key, src_md); - //SetMdMap(md_map, weights_md_key, weights_md); - //SetMdMap(md_map, dst_md_key, dst_md); + SetMdMap(md_map, src_md_key, src_md); + SetMdMap(md_map, weights_md_key, weights_md); + SetMdMap(md_map, dst_md_key, dst_md); } else{ - src_md = mds[2]; - weights_md = mds[3]; - dst_md = mds[4]; - //src_md = GetMdMap(md_map, src_md_key); - //weights_md = GetMdMap(md_map, weights_md_key); - //dst_md = GetMdMap(md_map, dst_md_key); + src_md = GetMdMap(md_map, src_md_key); + weights_md = GetMdMap(md_map, weights_md_key); + dst_md = GetMdMap(md_map, dst_md_key); } // create a conv primitive descriptor and save it for usage in backward if (bias) { @@ -695,11 +673,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { if(!md_reuse){ bias_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc( bias_tz, platform::MKLDNNGetDataType(), memory::format::x))); - mds[5] = bias_md; - //SetMdMap(md_map, bias_md_key, bias_md); + SetMdMap(md_map, bias_md_key, bias_md); } else{ - bias_md = mds[5]; - //bias_md = GetMdMap(md_map, bias_md_key); + bias_md = GetMdMap(md_map, bias_md_key); } conv_pd = ConvFwdPrimitiveDesc(*src_md, *weights_md, *bias_md, *dst_md, strides, paddings, mkldnn_engine, @@ -738,7 +714,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { std::shared_ptr dst_memory_p; bool need_s8_to_u8 = false; - //auto user_residual_md_key = key + "@user_residual_md"; + auto user_residual_md_key = key + "@user_residual_md"; if(fuse_residual_conn) { auto residual_param = ctx.Input("ResidualData"); PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(), @@ -754,11 +730,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { paddle::framework::ToMKLDNNDataType(residual_param->type()); user_residual_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc( residual_data_tz, residual_data_type, residual_param->format()))); - mds[6] = user_residual_md; - //SetMdMap(md_map, user_residual_md_key, user_residual_md); + SetMdMap(md_map, user_residual_md_key, user_residual_md); } else{ - user_residual_md = mds[6]; - //user_residual_md = GetMdMap(md_map, user_residual_md_key); + user_residual_md = GetMdMap(md_map, user_residual_md_key); } if(is_INT8){ PADDLE_ENFORCE( @@ -844,18 +818,16 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { // create convolution op primitive std::shared_ptr conv_p; //auto scale_bias_key = key + "@scale_bias"; - //auto user_bias_md_key = key + "@user_bias_md"; + auto user_bias_md_key = key + "@user_bias_md"; if (bias) { const float* bias_data = bias->data(); std::shared_ptr user_bias_md; if(!md_reuse){ user_bias_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc( {bias_tz}, platform::MKLDNNGetDataType(), memory::format::x))); - mds[7] = user_bias_md; - //SetMdMap(md_map, user_bias_md_key, user_bias_md); + SetMdMap(md_map, user_bias_md_key, user_bias_md); } else{ - user_bias_md = mds[7]; - //user_bias_md = GetMdMap(md_map, user_bias_md_key); + user_bias_md = GetMdMap(md_map, user_bias_md_key); } auto user_bias_memory_p = handler.AcquireBiasMemory(*user_bias_md, to_void_cast(bias_data)); @@ -891,7 +863,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { } SetScaleMap(scale_map, key, scale_datas); - SetMdMap(md_map, key, mds); // push primitive to stream and wait until it's executed pipeline.push_back(*conv_p); @@ -927,8 +898,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { return {{0.0f}}; } - void SetMdMap(std::unordered_map>> &md_map, - const std::string& name, std::vector> mds) const { + void SetMdMap(std::unordered_map> &md_map, + const std::string& name, std::shared_ptr mds) const { auto it = md_map.find(name); if (it == md_map.end()) { md_map[name] = mds; // create new blob @@ -938,13 +909,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { return; } - std::vector> GetMdMap(std::unordered_map>> md_map, + std::shared_ptr GetMdMap(std::unordered_map> md_map, const std::string& name) const { auto it = md_map.find(name); if (it != md_map.end()) { return (*it).second; } - return {}; + return nullptr; } mkldnn::primitive_attr CreatePostOps(bool fuse_relu, bool fuse_residual_conn,