diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc
index f1ecfe41b961620a07339ab4a75456acddd4203f..aec34421b6c9f4f28a896cbfb07f49dc1396beaf 100644
--- a/paddle/fluid/operators/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/conv_mkldnn_op.cc
@@ -376,7 +376,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     static std::unordered_map<std::string, std::vector<float>> scale_map;
 //scale_map.insert({key_conv_pd,{1.0f}});
 //scale_map[key_conv_pd]={0.1f};
-    bool scale_reuse = false;
+    bool scale_reuse = true;
     auto scale_in_key = key + "@scale_in";
     auto scale_weights_key = key + "@scale_weights";
     auto scale_out_key = key + "@scale_out";
@@ -389,14 +389,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     std::vector<float> scale_in_eltwise_data;
     std::vector<float> output_shift_scale;
     std::vector<float> sum_scale = {1.0f};
-    std::vector<float> none_scale = {0};
+    std::vector<float> none_scale = {0.0f};
     if (is_INT8 && GetScaleMap(scale_map, scale_in_key) == none_scale){
-      scale_reuse = true;
+      scale_reuse = false;
     }
 
 //std::cout<<"scale_reuse = "<<scale_reuse<<std::endl;
     if (is_INT8){
-      if(scale_reuse){
+      if(!scale_reuse){
         int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1;
         scale_in_data = {*(scale_in->data<float>())};
@@ -440,13 +440,31 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     }
-
+    static std::unordered_map<std::string, std::shared_ptr<mkldnn::memory::desc>> md_map;
+    bool md_reuse = true;
+    auto user_src_md_key = key + "@user_src_md";
+    if (GetMdMap(md_map, user_src_md_key) == nullptr){
+      md_reuse = false;  //we suppose all mds are reused if the first md is in the map.
+    }
+    auto user_weights_md_key = key + "@user_weights_md";
+    std::shared_ptr<mkldnn::memory::desc> user_src_md;
+    std::shared_ptr<mkldnn::memory::desc> user_weights_md;
     std::vector<mkldnn::primitive> pipeline;
-    auto user_src_md = platform::MKLDNNMemDesc(
-        {src_tz}, paddle::framework::ToMKLDNNDataType(input->type()), input->format());
-    auto user_weights_md = platform::MKLDNNMemDesc(
-        {weights_tz}, platform::MKLDNNGetDataType<float>(),
-        (g == 1) ? mkldnn::memory::format::oihw : mkldnn::memory::format::goihw);
+//std::cout<<"md_reuse = "<<md_reuse<<std::endl;
+    if(!md_reuse){
+      user_src_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
+          {src_tz}, paddle::framework::ToMKLDNNDataType(input->type()), input->format())));
+      user_weights_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
+          {weights_tz}, platform::MKLDNNGetDataType<float>(),
+          (g == 1) ? mkldnn::memory::format::oihw : mkldnn::memory::format::goihw)));
+
+      SetMdMap(md_map, user_src_md_key, user_src_md);
+      SetMdMap(md_map, user_weights_md_key, user_weights_md);
+    } else{
+      user_src_md = GetMdMap(md_map, user_src_md_key);
+      user_weights_md = GetMdMap(md_map, user_weights_md_key);
+    }
 
     /* create memory descriptor for convolution without specified format
      * ('any') which lets a primitive (convolution in this case) choose
@@ -458,53 +476,93 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 
     std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
     auto bias_tz = paddle::framework::vectorize2int(bias->dims());
+
+    auto src_md_key = key + "@src_md";
+    auto weights_md_key = key + "@weights_md_key";
+    auto dst_md_key = key + "@dst_md_key";
+    auto bias_md_key = key + "@bias_md_key";
+    std::shared_ptr<mkldnn::memory::desc> src_md;
+    std::shared_ptr<mkldnn::memory::desc> weights_md;
+    std::shared_ptr<mkldnn::memory::desc> dst_md;
+
     if(is_INT8){
-      auto src_md = platform::MKLDNNMemDesc(
-          src_tz, memory::data_type::u8, chosen_memory_format);
-      auto weights_md = platform::MKLDNNMemDesc(
-          weights_tz, memory::data_type::s8,
-          (g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw);
-      auto dst_dt = fuse_relu? paddle::framework::ToMKLDNNDataType(std::type_index(typeid(unsigned char))) : paddle::framework::ToMKLDNNDataType(std::type_index(typeid(signed char)));
-      if(fuse_residual_conn){
-        auto residual = ctx.Input<Tensor>("ResidualData");
-        auto residual_dt = paddle::framework::ToMKLDNNDataType(residual->type());
-        if(dst_dt != residual_dt)
-          dst_dt = residual_dt;
+      if(!md_reuse){
+        src_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
+            src_tz, memory::data_type::u8, chosen_memory_format)));
+        weights_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
+            weights_tz, memory::data_type::s8,
+            (g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw)));
+        auto dst_dt = fuse_relu? paddle::framework::ToMKLDNNDataType(std::type_index(typeid(unsigned char))) : paddle::framework::ToMKLDNNDataType(std::type_index(typeid(signed char)));
+        if(fuse_residual_conn){
+          auto residual = ctx.Input<Tensor>("ResidualData");
+          auto residual_dt = paddle::framework::ToMKLDNNDataType(residual->type());
+          if(dst_dt != residual_dt)
+            dst_dt = residual_dt;
+        }
+        dst_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format)));
+        SetMdMap(md_map, src_md_key, src_md);
+        SetMdMap(md_map, weights_md_key, weights_md);
+        SetMdMap(md_map, dst_md_key, dst_md);
+      } else{
+        src_md = GetMdMap(md_map, src_md_key);
+        weights_md = GetMdMap(md_map, weights_md_key);
+        dst_md = GetMdMap(md_map, dst_md_key);
       }
-      auto dst_md = platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format);
 
       // create a conv primitive descriptor and save it for usage in backward
       if (bias) {
-        auto bias_md = platform::MKLDNNMemDesc(
-            bias_tz, memory::data_type::s32, memory::format::x);
-        conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
+        std::shared_ptr<mkldnn::memory::desc> bias_md;
+        if(!md_reuse){
+          bias_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
+              bias_tz, memory::data_type::s32, memory::format::x)));
+          SetMdMap(md_map, bias_md_key, bias_md);
+        } else{
+          bias_md = GetMdMap(md_map, bias_md_key);
+        }
+
+        conv_pd = ConvFwdPrimitiveDesc(*src_md, *weights_md, *bias_md, *dst_md,
                                        strides, paddings, mkldnn_engine,
                                        fuse_relu, fuse_residual_conn,
                                        output_shift_scale, sum_scale[0], is_test);
       } else {
         conv_pd =
-            ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
+            ConvFwdPrimitiveDesc(*src_md, *weights_md, *dst_md, strides, paddings,
                                  mkldnn_engine, fuse_relu, fuse_residual_conn,
                                  output_shift_scale, sum_scale[0], is_test);
       }
     } else{
-      auto src_md = platform::MKLDNNMemDesc(
-          src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
-      auto weights_md = platform::MKLDNNMemDesc(
-          weights_tz, platform::MKLDNNGetDataType<T>(),
-          (g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw);
-      auto dst_md = platform::MKLDNNMemDesc(
-          dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
+      if(!md_reuse){
+        src_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
+            src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format)));
+        weights_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
+            weights_tz, platform::MKLDNNGetDataType<T>(),
+            (g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw)));
+        dst_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
+            dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format)));
+        SetMdMap(md_map, src_md_key, src_md);
+        SetMdMap(md_map, weights_md_key, weights_md);
+        SetMdMap(md_map, dst_md_key, dst_md);
+      } else{
+        src_md = GetMdMap(md_map, src_md_key);
+        weights_md = GetMdMap(md_map, weights_md_key);
+        dst_md = GetMdMap(md_map, dst_md_key);
+      }
 
       // create a conv primitive descriptor and save it for usage in backward
       if (bias) {
-        auto bias_md = platform::MKLDNNMemDesc(
-            bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
-        conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
-                                       strides, paddings, mkldnn_engine,
-                                       fuse_relu, fuse_residual_conn, is_test);
+        std::shared_ptr<mkldnn::memory::desc> bias_md;
+        if(!md_reuse){
+          bias_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
+              bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x)));
+          SetMdMap(md_map, bias_md_key, bias_md);
+        } else{
+          bias_md = GetMdMap(md_map, bias_md_key);
+        }
+        conv_pd = ConvFwdPrimitiveDesc(*src_md, *weights_md, *bias_md, *dst_md,
+                                       strides, paddings, mkldnn_engine,
+                                       fuse_relu, fuse_residual_conn, is_test);
       } else {
-        conv_pd =
-            ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
+        conv_pd =
+            ConvFwdPrimitiveDesc(*src_md, *weights_md, *dst_md, strides, paddings,
                                  mkldnn_engine, fuse_relu, fuse_residual_conn, is_test);
       }
     }
@@ -515,9 +573,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 
     // create mkldnn memory from input tensors (data/weights)
     auto user_src_memory_p =
-        handler.AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
+        handler.AcquireSrcMemory(*user_src_md, to_void_cast<T>(input_data));
     auto user_weights_memory_p = handler.AcquireWeightsMemory(
-        user_weights_md, to_void_cast<float>(filter_data));
+        *user_weights_md, to_void_cast<float>(filter_data));
 
     // create reorder primitive if the input format is not the preferred one
     auto src_memory_p =
@@ -535,6 +593,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 
     std::shared_ptr<mkldnn::memory> dst_memory_p;
     bool need_s8_to_u8 = false;
+    auto user_residual_md_key = key + "@user_residual_md";
     if(fuse_residual_conn) {
       auto residual_param = ctx.Input<Tensor>("ResidualData");
       PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(),
@@ -542,42 +601,48 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                         "same dimension sizes");
       auto residual_dt = paddle::framework::ToMKLDNNDataType(residual_param->type());
       if(residual_param->format() != handler.GetDstFormat()) {
-        auto residual_data_tz =
-            paddle::framework::vectorize2int(residual_param->dims());
-        auto residual_data_type =
-            paddle::framework::ToMKLDNNDataType(residual_param->type());
-        auto user_residual_md = platform::MKLDNNMemDesc(
-            residual_data_tz, residual_data_type, residual_param->format());
+        std::shared_ptr<mkldnn::memory::desc> user_residual_md;
+        if(!md_reuse){
+          auto residual_data_tz =
+              paddle::framework::vectorize2int(residual_param->dims());
+          auto residual_data_type =
+              paddle::framework::ToMKLDNNDataType(residual_param->type());
+          user_residual_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
+              residual_data_tz, residual_data_type, residual_param->format())));
+          SetMdMap(md_map, user_residual_md_key, user_residual_md);
+        } else{
+          user_residual_md = GetMdMap(md_map, user_residual_md_key);
+        }
         if(is_INT8){
           if(residual_dt == mkldnn::memory::data_type::u8){
-            auto residual_param_data = residual_param->data<uint8_t>();
-            auto user_residual_memory_p = handler.AcquireResidualDataMemory(
-                user_residual_md, to_void_cast<uint8_t>(residual_param_data));
-            PADDLE_ENFORCE(
-                residual_param_data != nullptr,
-                "Provide data if you want MKLDNN conv+elementwise_add fusion");
-            uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
-            dst_memory_p =
-                handler.AcquireDstMemoryFromResidualDataMemory(
-                    user_residual_memory_p, to_void_cast<uint8_t>(output_data), pipeline);
+            auto residual_param_data = residual_param->data<uint8_t>();
+            auto user_residual_memory_p = handler.AcquireResidualDataMemory(
+                *user_residual_md, to_void_cast<uint8_t>(residual_param_data));
+            PADDLE_ENFORCE(
+                residual_param_data != nullptr,
+                "Provide data if you want MKLDNN conv+elementwise_add fusion");
+            uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
+            dst_memory_p =
+                handler.AcquireDstMemoryFromResidualDataMemory(
+                    user_residual_memory_p, to_void_cast<uint8_t>(output_data), pipeline);
           } else{
-            auto residual_param_data = residual_param->data<int8_t>();
-            auto user_residual_memory_p = handler.AcquireResidualDataMemory(
-                user_residual_md, to_void_cast<int8_t>(residual_param_data));
-            PADDLE_ENFORCE(
-                residual_param_data != nullptr,
-                "Provide data if you want MKLDNN conv+elementwise_add fusion");
-            int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace());
-            dst_memory_p =
-                handler.AcquireDstMemoryFromResidualDataMemory(
-                    user_residual_memory_p, to_void_cast<int8_t>(output_data), pipeline);
+            auto residual_param_data = residual_param->data<int8_t>();
+            auto user_residual_memory_p = handler.AcquireResidualDataMemory(
+                *user_residual_md, to_void_cast<int8_t>(residual_param_data));
+            PADDLE_ENFORCE(
+                residual_param_data != nullptr,
+                "Provide data if you want MKLDNN conv+elementwise_add fusion");
+            int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace());
+            dst_memory_p =
+                handler.AcquireDstMemoryFromResidualDataMemory(
+                    user_residual_memory_p, to_void_cast<int8_t>(output_data), pipeline);
            if(fuse_relu)
              need_s8_to_u8 = true;
          }
        } else{
          auto residual_param_data = residual_param->data<T>();
          auto user_residual_memory_p = handler.AcquireResidualDataMemory(
-              user_residual_md, to_void_cast<T>(residual_param_data));
+              *user_residual_md, to_void_cast<T>(residual_param_data));
          PADDLE_ENFORCE(
              residual_param_data != nullptr,
              "Provide data if you want MKLDNN conv+elementwise_add fusion");
@@ -630,16 +695,23 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     std::shared_ptr<mkldnn::convolution_forward> conv_p;
     std::vector<float> scale_bias_data;
     auto scale_bias_key = key + "@scale_bias";
+    auto user_bias_md_key = key + "@user_bias_md";
     if (bias) {
       const float* bias_data = bias->data<float>();
-      auto user_bias_md = platform::MKLDNNMemDesc(
-          {bias_tz}, platform::MKLDNNGetDataType<float>(), memory::format::x);
+      std::shared_ptr<mkldnn::memory::desc> user_bias_md;
+      if(!md_reuse){
+        user_bias_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
+            {bias_tz}, platform::MKLDNNGetDataType<float>(), memory::format::x)));
+        SetMdMap(md_map, user_bias_md_key, user_bias_md);
+      } else{
+        user_bias_md = GetMdMap(md_map, user_bias_md_key);
+      }
       auto user_bias_memory_p =
-          handler.AcquireBiasMemory(user_bias_md, to_void_cast<float>(bias_data));
+          handler.AcquireBiasMemory(*user_bias_md, to_void_cast<float>(bias_data));
       std::shared_ptr<mkldnn::memory> bias_memory_p;
       if(is_INT8){
         int mask_reorder = is_multi_channel? 1<<0 : 1;
-        if(scale_reuse){
+        if(!scale_reuse){
           int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1;
           scale_bias_data.resize(count);
 #pragma omp parallel for if (count > 1)
@@ -689,13 +761,33 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     return;
   }
 
-  std::vector<float> GetScaleMap(std::unordered_map<std::string, std::vector<float>> &scale_map,
+  std::vector<float> GetScaleMap(std::unordered_map<std::string, std::vector<float>> scale_map,
                       const std::string& name) const {
     auto it = scale_map.find(name);
    if (it != scale_map.end()) {
      return (*it).second;
    }
-    return {0};
+    return {0.0f};
+  }
+
+  void SetMdMap(std::unordered_map<std::string, std::shared_ptr<mkldnn::memory::desc>> &md_map,
+                      const std::string& name, std::shared_ptr<mkldnn::memory::desc> md) const {
+    auto it = md_map.find(name);
+    if (it == md_map.end()) {
+      md_map[name] = md;  // create new blob
+    } else {
+      (*it).second = md;  // set data to existing blob
+    }
+    return;
+  }
+
+  std::shared_ptr<mkldnn::memory::desc> GetMdMap(std::unordered_map<std::string, std::shared_ptr<mkldnn::memory::desc>> md_map,
+                      const std::string& name) const {
+    auto it = md_map.find(name);
+    if (it != md_map.end()) {
+      return (*it).second;
+    }
+    return nullptr;
   }
 
   mkldnn::primitive_attr CreatePostOps(bool fuse_relu, bool fuse_residual_conn,
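
Note on the pattern introduced above: SetMdMap/GetMdMap (like the existing scale_map helpers) implement a keyed lookaside cache of shared_ptr-held memory descriptors, so a later invocation of the kernel with the same key reuses the descriptors built on the first run instead of recreating them. Below is a minimal standalone sketch of that caching pattern, for illustration only; it assumes nothing beyond the C++ standard library, and the DescCache class and its std::string payload are hypothetical stand-ins for the mkldnn::memory::desc values handled in the patch.

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

// Keyed cache of shared_ptr-held values: an entry is created once per key
// and handed back on every later lookup (mirrors md_map + Set/GetMdMap).
class DescCache {
 public:
  // Returns the cached entry for the key, or nullptr on a miss
  // (GetMdMap in the patch signals a miss the same way).
  std::shared_ptr<std::string> Get(const std::string& key) const {
    auto it = cache_.find(key);
    return it != cache_.end() ? it->second : nullptr;
  }

  // Inserts or overwrites the entry for the key (mirrors SetMdMap).
  void Set(const std::string& key, std::shared_ptr<std::string> value) {
    cache_[key] = value;
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<std::string>> cache_;
};

int main() {
  static DescCache md_cache;  // static, like md_map in the kernel
  const std::string key = "conv1@user_src_md";

  auto md = md_cache.Get(key);
  bool reuse = (md != nullptr);  // corresponds to md_reuse in the patch
  if (!reuse) {
    // First run for this key: build the descriptor and remember it.
    md = std::make_shared<std::string>("nchw/u8 source descriptor");
    md_cache.Set(key, md);
  }
  std::cout << (reuse ? "reused: " : "created: ") << *md << "\n";
  return 0;
}

The same reuse-flag idea explains the scale_reuse flip in the patch: the flag now starts as true and is cleared only when the key is missing from the map, so the expensive scale computation runs under the negated check (!scale_reuse).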