diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc index 31fcadb2066fbc969573ec442e99027c7e8396da..2900856754670df51cc6a7f3db2e42d577478215 100644 --- a/paddle/fluid/operators/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/conv_mkldnn_op.cc @@ -480,12 +480,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine, fuse_relu, fuse_residual_conn, - output_shift_scale, sum_scale[0]); + output_shift_scale, sum_scale[0], is_test); } else { conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings, mkldnn_engine, fuse_relu, fuse_residual_conn, - output_shift_scale, sum_scale[0]); + output_shift_scale, sum_scale[0], is_test); } } else{ auto src_md = platform::MKLDNNMemDesc( @@ -501,11 +501,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { bias_tz, platform::MKLDNNGetDataType(), memory::format::x); conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine, - fuse_relu, fuse_residual_conn); + fuse_relu, fuse_residual_conn, is_test); } else { conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings, - mkldnn_engine, fuse_relu, fuse_residual_conn); + mkldnn_engine, fuse_relu, fuse_residual_conn, is_test); } } // Save conv_pd/src_memory/weights_memory for backward pass @@ -743,12 +743,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { const std::vector& paddings, const mkldnn::engine& engine, const bool fuse_relu, const bool fuse_residual_conn, - const std::vector output_shift_scale, const float sum_scale) const { + const std::vector output_shift_scale, const float sum_scale, bool is_test) const { memory::dims stride_dims = {strides[0], strides[1]}; memory::dims padding_dims = {paddings[0], paddings[1]}; + auto propagation = is_test ? mkldnn::prop_kind::forward_scoring : mkldnn::prop_kind::forward_training; + auto conv_desc = mkldnn::convolution_forward::desc( - mkldnn::prop_kind::forward_scoring, mkldnn::convolution_direct, src, weights, + propagation, mkldnn::convolution_direct, src, weights, dst, stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero); @@ -767,12 +769,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { const memory::desc& dst, const std::vector& strides, const std::vector& paddings, const mkldnn::engine& engine, const bool fuse_relu, - const bool fuse_residual_conn) const{ + const bool fuse_residual_conn, bool is_test) const{ memory::dims stride_dims = {strides[0], strides[1]}; memory::dims padding_dims = {paddings[0], paddings[1]}; - + + auto propagation = is_test ? mkldnn::prop_kind::forward_scoring : mkldnn::prop_kind::forward_training; + auto conv_desc = mkldnn::convolution_forward::desc( - mkldnn::prop_kind::forward_scoring, mkldnn::convolution_direct, src, weights, + propagation, mkldnn::convolution_direct, src, weights, dst, stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero); @@ -792,12 +796,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { const std::vector& paddings, const mkldnn::engine& engine, const bool fuse_relu, const bool fuse_residual_conn, - const std::vector output_shift_scale, const float sum_scale) const { + const std::vector output_shift_scale, const float sum_scale, bool is_test) const { memory::dims stride_dims = {strides[0], strides[1]}; memory::dims padding_dims = {paddings[0], paddings[1]}; + auto propagation = is_test ? mkldnn::prop_kind::forward_scoring : mkldnn::prop_kind::forward_training; + auto conv_desc = mkldnn::convolution_forward::desc( - mkldnn::prop_kind::forward_scoring, mkldnn::convolution_direct, src, weights, + propagation, mkldnn::convolution_direct, src, weights, bias, dst, stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero); @@ -817,12 +823,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { const std::vector& strides, const std::vector& paddings, const mkldnn::engine& engine, const bool fuse_relu, - const bool fuse_residual_conn) const{ + const bool fuse_residual_conn, bool is_test) const{ memory::dims stride_dims = {strides[0], strides[1]}; memory::dims padding_dims = {paddings[0], paddings[1]}; + auto propagation = is_test ? mkldnn::prop_kind::forward_scoring : mkldnn::prop_kind::forward_training; + auto conv_desc = mkldnn::convolution_forward::desc( - mkldnn::prop_kind::forward_scoring, mkldnn::convolution_direct, src, weights, + propagation, mkldnn::convolution_direct, src, weights, bias, dst, stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero);