diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc index 43aba94dd2b45c763dfa6681b13a3852dc4be325..11b7efbe00cf80c2a5988fd8da1b8888ec617002 100644 --- a/paddle/fluid/operators/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/conv_mkldnn_op.cc @@ -131,17 +131,24 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler { std::shared_ptr AcquireWeightsMemoryFromPrimitive( const std::shared_ptr user_weights_memory_p, std::vector& pipeline, // NOLINT - bool is_persistent = false) { + bool is_persistent = false, + bool is_INT8 = false, + std::vector scale_data = {1.0f}, + int mask = 0) { auto user_weights_pd = user_weights_memory_p->get_primitive_desc(); auto weights_pd = conv_pd_->weights_primitive_desc(); return this->AcquireMemory(weights_pd, user_weights_pd, user_weights_memory_p, "@weights_mem_p", - pipeline, is_persistent); + pipeline, is_persistent, + is_INT8, scale_data, mask); } std::shared_ptr AcquireBiasMemoryFromPrimitive( const std::shared_ptr user_bias_memory_p, - std::vector& pipeline) { // NOLINT + std::vector& pipeline, + bool is_INT8 = false, + std::vector scale_data = {1.0f}, + int mask = 0) { // NOLINT auto user_bias_pd = user_bias_memory_p->get_primitive_desc(); auto bias_pd = conv_pd_->bias_primitive_desc(); return this->AcquireMemory(bias_pd, user_bias_pd, user_bias_memory_p, @@ -283,6 +290,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { auto* scale_in_eltwise = ctx.HasInput("Scale_in_eltwise")? ctx.Input("Scale_in_eltwise") : nullptr; auto* scale_weights = ctx.HasInput("Scale_weights")? ctx.Input("Scale_weights") : nullptr; auto* scale_out = ctx.HasInput("Scale_out")? ctx.Input("Scale_out") : nullptr; + bool is_multi_channel = (is_INT8 && scale_weights->memory_size() > 1) ? true : false; PADDLE_ENFORCE(input->layout() == DataLayout::kMKLDNN && input->format() != memory::format::format_undef, @@ -338,12 +346,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { std::vector output_shift_scale; T sum_scale = 1.0f; if(is_INT8){ - int count = g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]; + int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1; T scale_in_data = *(scale_in->data()); T scale_in_eltwise_data = *(scale_in_eltwise->data()); std::vector scale_weights_data(count); for(int i=0; idata()); + scale_weights_data[i] =*(scale_weights->data() + i); } T scale_out_data = *(scale_out->data()); @@ -436,6 +444,16 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline); auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive( user_weights_memory_p, pipeline, is_test); + if(is_INT8){ + int mask_reorder = is_multi_channel? 0 : ((g!= 1) ? (1<<1)+(1<<0) : 1<<0); + int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1; + std::vector scale_weights_data(count); + for(int i=0; idata() + i); + } + auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive( + user_weights_memory_p, pipeline, is_test, is_INT8, scale_weights_data, mask_reorder); + } auto dst_memory_p = handler.AcquireDstMemoryFromPrimitive(to_void_cast(output_data)); @@ -447,9 +465,18 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { {bias_tz}, platform::MKLDNNGetDataType(), memory::format::x); auto user_bias_memory_p = handler.AcquireBiasMemory(user_bias_md, to_void_cast(bias_data)); - auto bias_memory_p = handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline); + if(is_INT8){ + int mask_reorder = is_multi_channel? 0 : 1<<0; + int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1; + std::vector scale_bias_data(count); + for(int i=0; idata()) * (*(scale_weights->data() + i)); + } + auto bias_memory_p = + handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline, is_INT8, scale_bias_data, mask_reorder); + } conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p, bias_memory_p, dst_memory_p); } else { @@ -470,7 +497,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { const std::vector output_shift_scale, T sum_scale) const { mkldnn::primitive_attr conv_attr; mkldnn::post_ops post_operations; - int mask = 0; + int mask = output_shift_scale.size() > 1 ? 1<<1 : 0; conv_attr.set_output_scales(mask, output_shift_scale); if (fuse_eltwise) { post_operations.append_sum(sum_scale);