diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc index 10d3b6ce4f79e7c59d7d3588b3d481d01ef04c46..647e09a92911e327ba01b7bb23fdb617f949cea4 100644 --- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc @@ -383,14 +383,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { const std::string key_conv_pd = key + "@conv_pd"; bool need_s8_to_u8 = false; - std::shared_ptr conv_p = nullptr; - std::shared_ptr src_memory_p = nullptr; - std::shared_ptr user_src_memory_p = nullptr; - std::shared_ptr dst_memory_p = nullptr; + std::shared_ptr conv_p; + std::shared_ptr src_memory_p; + std::shared_ptr user_src_memory_p; + std::shared_ptr dst_memory_p; std::vector pipeline; - std::shared_ptr conv_pd = - nullptr; - std::shared_ptr handler = nullptr; + std::shared_ptr conv_pd; + std::shared_ptr handler; auto prim_key = key + "@conv_p"; auto dst_key = key + "@dst_mem_p"; @@ -460,24 +459,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { // TODO(lidanqing): We use relu post-op instead of brelu post-op cause // mkldnn v0.18 does not support INT8 brelu post-op. Use code in /**/ when // v0.20 is enabled + std::shared_ptr bias_md_p; if (bias) { bias_tz = paddle::framework::vectorize2int(bias->dims()); - auto bias_md = platform::MKLDNNMemDesc(bias_tz, memory::data_type::s32, - memory::format::x); - - conv_pd = ConvFwdPrimitiveDesc( - src_md, weights_md, bias_md, dst_md, strides, paddings, - mkldnn_engine, fuse_relu || fuse_brelu /*fuse_relu*/, - fuse_residual_conn, false /*fuse_brelu*/, fuse_brelu_threshold, - output_shift_scale, sum_scale, is_test); - - } else { - conv_pd = ConvFwdPrimitiveDesc( - src_md, weights_md, dst_md, strides, paddings, mkldnn_engine, - fuse_relu || fuse_brelu /*fuse_relu*/, fuse_residual_conn, - false /*fuse_brelu*/, fuse_brelu_threshold, output_shift_scale, - sum_scale, is_test); + bias_md_p = std::make_shared(platform::MKLDNNMemDesc( + bias_tz, memory::data_type::s32, memory::format::x)); } + conv_pd = ConvFwdPrimitiveDesc( + src_md, weights_md, bias_md_p, dst_md, strides, paddings, + mkldnn_engine, fuse_relu || fuse_brelu /*fuse_relu*/, + fuse_residual_conn, false /*fuse_brelu*/, fuse_brelu_threshold, + output_shift_scale, sum_scale, is_test); // Save conv_pd/src_memory/weights_memory for backward pass dev_ctx.SetBlob(key_conv_pd, conv_pd); handler.reset(new platform::ConvMKLDNNHandler(conv_pd, dev_ctx, @@ -649,7 +641,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { private: mkldnn::primitive_attr CreatePostOps( bool fuse_relu, bool fuse_residual_conn, - const std::vector output_shift_scale, float sum_scale, + const std::vector& output_shift_scale, float sum_scale, bool fuse_brelu, float fuse_brelu_threshold) const { mkldnn::primitive_attr conv_attr; mkldnn::post_ops post_operations; @@ -679,52 +671,29 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { std::unique_ptr ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights, + const std::shared_ptr bias_md_p, const memory::desc& dst, const std::vector& strides, const std::vector& paddings, const mkldnn::engine& engine, const bool fuse_relu, const bool fuse_residual_conn, const bool fuse_brelu, const float fuse_brelu_threshold, - const std::vector output_shift_scale, + const std::vector& output_shift_scale, const float sum_scale, bool is_test) const { memory::dims stride_dims = {strides[0], strides[1]}; memory::dims padding_dims = {paddings[0], paddings[1]}; auto propagation = is_test ? mkldnn::prop_kind::forward_scoring : mkldnn::prop_kind::forward_training; - - auto conv_desc = mkldnn::convolution_forward::desc( - propagation, mkldnn::convolution_direct, src, weights, dst, stride_dims, - padding_dims, padding_dims, mkldnn::padding_kind::zero); - mkldnn::primitive_attr conv_attr = - CreatePostOps(fuse_relu, fuse_residual_conn, output_shift_scale, - sum_scale, fuse_brelu, fuse_brelu_threshold); - - auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc( - conv_desc, conv_attr, engine); - - return std::unique_ptr( - p_conv_pd); - } - - std::unique_ptr - ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights, - const memory::desc& bias, const memory::desc& dst, - const std::vector& strides, - const std::vector& paddings, - const mkldnn::engine& engine, const bool fuse_relu, - const bool fuse_residual_conn, const bool fuse_brelu, - const float fuse_brelu_threshold, - const std::vector output_shift_scale, - const float sum_scale, bool is_test) const { - memory::dims stride_dims = {strides[0], strides[1]}; - memory::dims padding_dims = {paddings[0], paddings[1]}; - - auto propagation = is_test ? mkldnn::prop_kind::forward_scoring - : mkldnn::prop_kind::forward_training; - - auto conv_desc = mkldnn::convolution_forward::desc( - propagation, mkldnn::convolution_direct, src, weights, bias, dst, - stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero); + auto conv_desc = + (bias_md_p != nullptr) + ? mkldnn::convolution_forward::desc( + propagation, mkldnn::convolution_direct, src, weights, + (*bias_md_p), dst, stride_dims, padding_dims, padding_dims, + mkldnn::padding_kind::zero) + : mkldnn::convolution_forward::desc( + propagation, mkldnn::convolution_direct, src, weights, dst, + stride_dims, padding_dims, padding_dims, + mkldnn::padding_kind::zero); mkldnn::primitive_attr conv_attr = CreatePostOps(fuse_relu, fuse_residual_conn, output_shift_scale,