diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc index 0a023d8b215cc6f7dc6b00a8c03c021e195ef0bd..31fcadb2066fbc969573ec442e99027c7e8396da 100644 --- a/paddle/fluid/operators/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/conv_mkldnn_op.cc @@ -401,11 +401,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1; scale_in_data = {*(scale_in->data())}; scale_weights_data.resize(count); + #pragma omp parallel for if (count > 1) for(int i=0; idata() + i); } scale_out_data = {*(scale_out->data())}; output_shift_scale.resize(count); + #pragma omp parallel for if (count > 1) for(int i=0; i { auto chosen_memory_format = platform::data_format_to_memory_format(data_format); - auto src_md = platform::MKLDNNMemDesc( - src_tz, platform::MKLDNNGetDataType(), chosen_memory_format); - auto weights_md = platform::MKLDNNMemDesc( - weights_tz, platform::MKLDNNGetDataType(), - (g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw); - auto dst_md = platform::MKLDNNMemDesc( - dst_tz, platform::MKLDNNGetDataType(), chosen_memory_format); - std::vector bias_tz; - + std::shared_ptr conv_pd; + auto bias_tz = paddle::framework::vectorize2int(bias->dims()); if(is_INT8){ - src_md = platform::MKLDNNMemDesc( + auto src_md = platform::MKLDNNMemDesc( src_tz, memory::data_type::u8, chosen_memory_format); - weights_md = platform::MKLDNNMemDesc( + auto weights_md = platform::MKLDNNMemDesc( weights_tz, memory::data_type::s8, (g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw); auto dst_dt = fuse_relu? paddle::framework::ToMKLDNNDataType(std::type_index(typeid(unsigned char))) : paddle::framework::ToMKLDNNDataType(std::type_index(typeid(signed char))); if(fuse_residual_conn){ auto residual = ctx.Input("ResidualData"); auto residual_dt = paddle::framework::ToMKLDNNDataType(residual->type()); - if(dst_dt != residual_dt) + if(dst_dt != residual_dt) dst_dt = residual_dt; } - dst_md = platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format); - } + auto dst_md = platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format); - // create a conv primitive descriptor and save it for usage in backward - std::shared_ptr conv_pd; - if (bias) { - bias_tz = paddle::framework::vectorize2int(bias->dims()); - auto bias_md = platform::MKLDNNMemDesc( - bias_tz, platform::MKLDNNGetDataType(), memory::format::x); - if(is_INT8){ - bias_md = platform::MKLDNNMemDesc( + // create a conv primitive descriptor and save it for usage in backward + if (bias) { + auto bias_md = platform::MKLDNNMemDesc( bias_tz, memory::data_type::s32, memory::format::x); - } - if(is_INT8){ conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine, - fuse_relu, fuse_residual_conn, + fuse_relu, fuse_residual_conn, output_shift_scale, sum_scale[0]); - } else{ - conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md, - strides, paddings, mkldnn_engine, - fuse_relu, fuse_residual_conn); - } - } else { - if(is_INT8){ + } else { conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings, mkldnn_engine, fuse_relu, fuse_residual_conn, output_shift_scale, sum_scale[0]); - } else{ - conv_pd = - ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings, - mkldnn_engine, fuse_relu, fuse_residual_conn); + } + } else{ + auto src_md = platform::MKLDNNMemDesc( + src_tz, platform::MKLDNNGetDataType(), chosen_memory_format); + auto weights_md = platform::MKLDNNMemDesc( + weights_tz, platform::MKLDNNGetDataType(), + (g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw); + auto dst_md = platform::MKLDNNMemDesc( + dst_tz, platform::MKLDNNGetDataType(), chosen_memory_format); + // create a conv primitive descriptor and save it for usage in backward + if (bias) { + auto bias_md = platform::MKLDNNMemDesc( + bias_tz, platform::MKLDNNGetDataType(), memory::format::x); + conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md, + strides, paddings, mkldnn_engine, + fuse_relu, fuse_residual_conn); + } else { + conv_pd = + ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings, + mkldnn_engine, fuse_relu, fuse_residual_conn); } } // Save conv_pd/src_memory/weights_memory for backward pass @@ -634,6 +631,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { if(scale_reuse){ int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1; scale_bias_data.resize(count); + #pragma omp parallel for if (count > 1) for(int i=0; i { // the scale parameter. It is assumed that when fuse_residual_conn is true, the // Output tensor contains the data coming from residual connection. The // result of this post_op is: Output = scale * Output + Conv_Out. - + conv_attr.set_output_scales(0, {1.0f}); if (fuse_residual_conn) { post_operations.append_sum(1.0f); } @@ -774,7 +772,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { memory::dims padding_dims = {paddings[0], paddings[1]}; auto conv_desc = mkldnn::convolution_forward::desc( - mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights, + mkldnn::prop_kind::forward_scoring, mkldnn::convolution_direct, src, weights, dst, stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero); @@ -824,7 +822,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { memory::dims padding_dims = {paddings[0], paddings[1]}; auto conv_desc = mkldnn::convolution_forward::desc( - mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights, + mkldnn::prop_kind::forward_scoring, mkldnn::convolution_direct, src, weights, bias, dst, stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero);