diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc index bd9376bd6cb9c5c1584215101fd46e387d862b54..ce72ec36635696c109bf5ef266a1a4627ea72709 100644 --- a/paddle/fluid/operators/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/conv_mkldnn_op.cc @@ -173,7 +173,6 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler { dev_ctx_.SetBlob(prim_key, conv_p); } else { -std::cout<<"4 is reuse = "<( dev_ctx_.GetBlob(prim_key)); - //is_reusing_ = false; PADDLE_ENFORCE((conv_p != nullptr) || (is_reusing_ == false), "Fail to find convolution primitive in device context"); if (conv_p == nullptr) { @@ -197,7 +195,6 @@ std::cout<<"4 is reuse = "< pipeline; -std::cout<<"log1....."<(), mkldnn::memory::format::nChw16c); -std::cout<<"log2....."<(), input->format()); auto user_weights_md = platform::MKLDNNMemDesc( {weights_tz}, platform::MKLDNNGetDataType(), (g == 1) ? filter->format() : mkldnn::memory::format::goihw); -std::cout<<"log3....."<("data_format"); - auto chosen_memory_format = + auto chosen_memory_format = platform::data_format_to_memory_format(data_format); - //std::shared_ptr src_md; - //std::shared_ptr weights_md; - //std::shared_ptr dst_md; - std::vector bias_tz; - - //if(is_INT8){ - // src_md.reset(new platform::MKLDNNMemDesc( - // src_tz, memory::data_type::u8, chosen_memory_format)); - // weights_md.reset(new platform::MKLDNNMemDesc( - // weights_tz, memory::data_type::s8, - // (g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw)); - // dst_md.reset(new platform::MKLDNNMemDesc( - // dst_tz, - // fuse_relu?memory::data_type::u8:memory::data_type::s8, - // chosen_memory_format)); - //} else{ - // src_md.reset(new platform::MKLDNNMemDesc( - // src_tz, platform::MKLDNNGetDataType(), chosen_memory_format)); - // weights_md.reset(new platform::MKLDNNMemDesc( - // weights_tz, platform::MKLDNNGetDataType(), - // (g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw)); - // dst_md.reset(new platform::MKLDNNMemDesc( - // dst_tz, platform::MKLDNNGetDataType(), chosen_memory_format)); - //} + auto src_md = platform::MKLDNNMemDesc( - src_tz, platform::MKLDNNGetDataType(), chosen_memory_format); + src_tz, platform::MKLDNNGetDataType(), chosen_memory_format); auto weights_md = platform::MKLDNNMemDesc( - weights_tz, platform::MKLDNNGetDataType(), + weights_tz, platform::MKLDNNGetDataType(), (g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw); + std::vector bias_tz; // TODO(mgallus): avoid empty vector creation. + // Currently used whenever bias is != nullptr. auto dst_md = platform::MKLDNNMemDesc( - dst_tz, platform::MKLDNNGetDataType(), chosen_memory_format); + dst_tz, platform::MKLDNNGetDataType(), chosen_memory_format); + if(is_INT8){ src_md = platform::MKLDNNMemDesc( src_tz, memory::data_type::u8, chosen_memory_format); weights_md = platform::MKLDNNMemDesc( - weights_tz, memory::data_type::s8, + weights_tz, memory::data_type::s8, (g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw); dst_md = platform::MKLDNNMemDesc( - dst_tz, + dst_tz, fuse_relu?memory::data_type::u8:memory::data_type::s8, chosen_memory_format); } @@ -494,7 +467,7 @@ std::cout<<"log3....."<1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1; std::vector scale_weights_data(count); for(int i=0; idata() + i); + scale_weights_data[i] = *(scale_weights->data() + i); } auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive( user_weights_memory_p, pipeline, is_test, is_INT8, scale_weights_data, mask_reorder); @@ -553,8 +526,8 @@ std::cout<<"input fmt = "<format()<<" output fmt = "<format()< {bias_tz}, platform::MKLDNNGetDataType(), memory::format::x); auto user_bias_memory_p = handler.AcquireBiasMemory(user_bias_md, to_void_cast(bias_data)); - std::shared_ptr bias_memory_p;// = - //handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline); + auto bias_memory_p = + handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline); if(is_INT8){ int mask_reorder = is_multi_channel? 0 : 1<<0; int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1; @@ -562,12 +535,9 @@ std::cout<<"input fmt = "<format()<<" output fmt = "<format()< for(int i=0; idata()) * (*(scale_weights->data() + i)); } - bias_memory_p = + auto bias_memory_p = handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline, is_INT8, scale_bias_data, mask_reorder); - } else{ - bias_memory_p = - handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline); - } + } conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p, bias_memory_p, dst_memory_p); } else { diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index 48d6344ed9dc73df85b850b0782415c0ff3416f9..c99966dbcf224de5552fdef8c481533d620eccda 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -70,7 +70,6 @@ inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector& dims, mkldnn::memory::data_type data_type, mkldnn::memory::format format) { mkldnn::memory::dims tz = dims; - std::cout<<"this is MKLDNNMemDesc"<<" data_type"<set_data_handle(ptr); // Mark that reusing happenned. All primitives from operator instance // should be reused or none of them. So we check consistency -std::cout<<"1 is reuse = "<set_data_handle(ptr); // Mark that reusing happenned. All primitives from operator instance // should be reused or none of them. So we check consistency -std::cout<<"2 is reuse = "<