From 9ecd8ee789b811e69e773a45e8258cb13e6cd8f8 Mon Sep 17 00:00:00 2001 From: lidanqing Date: Thu, 25 Jul 2019 05:02:10 +0200 Subject: [PATCH] change ComputeINT8 to template version to remove checking dst_datatype code (#18756) * change INT8 to template so that checking dst_dt with if-else could be removed. CI will be enabled after fixing reviews * reverse user_residual_memory_p and user_bias_memory_p declaration scope test=develop --- .../fluid/operators/mkldnn/conv_mkldnn_op.cc | 228 ++++++------------ paddle/fluid/platform/mkldnn_helper.h | 15 +- paddle/fluid/platform/mkldnn_reuse.h | 27 ++- 3 files changed, 98 insertions(+), 172 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc index 876a0b8b60..8aeb1264ce 100644 --- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc @@ -69,6 +69,26 @@ inline mkldnn::memory::format GetWeightsFormat(mkldnn::memory::format format, } } +static mkldnn::memory::data_type GetDstType(bool is_int8, + bool force_fp32_output, + bool fuse_relu, bool fuse_brelu, + bool fuse_residual_conn, + const Tensor* residual_param) { + auto dst_dt = mkldnn::memory::data_type::f32; // uint8_t, int8_t, float + if (is_int8) { + dst_dt = (fuse_relu || fuse_brelu) ? mkldnn::memory::data_type::u8 + : mkldnn::memory::data_type::s8; + if (force_fp32_output) { + dst_dt = mkldnn::memory::data_type::f32; + } + if (fuse_residual_conn && residual_param) { + auto residual_dt = framework::ToMKLDNNDataType(residual_param->type()); + if (dst_dt != residual_dt) dst_dt = residual_dt; + } + } + return dst_dt; +} + template class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { public: @@ -80,7 +100,20 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { if (!is_INT8) { ComputeFP32(ctx); } else { - ComputeINT8(ctx); + bool fuse_relu = ctx.Attr("fuse_relu"); + bool fuse_residual_conn = ctx.Attr("fuse_residual_connection"); + bool fuse_brelu = ctx.Attr("fuse_brelu"); + bool force_fp32_output = ctx.Attr("force_fp32_output"); + auto residual_param = ctx.Input("ResidualData"); + auto dst_dt = GetDstType(true, force_fp32_output, fuse_relu, fuse_brelu, + fuse_residual_conn, residual_param); + if (dst_dt == mkldnn::memory::data_type::f32) { + ComputeINT8(ctx); + } else if (dst_dt == mkldnn::memory::data_type::u8) { + ComputeINT8(ctx); + } else if (dst_dt == mkldnn::memory::data_type::s8) { + ComputeINT8(ctx); + } } } @@ -287,7 +320,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { output->set_layout(DataLayout::kMKLDNN); output->set_format(GetMKLDNNFormat(*dst_memory_p)); } - + template void ComputeINT8(const paddle::framework::ExecutionContext& ctx) const { const bool is_test = ctx.Attr("is_test"); @@ -328,10 +361,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { float fuse_brelu_threshold = ctx.Attr("fuse_brelu_threshold"); bool force_fp32_output = ctx.Attr("force_fp32_output"); bool unsigned_output = fuse_relu || fuse_brelu; - if (fuse_residual_conn) { - PADDLE_ENFORCE(force_fp32_output != true, - "residual fusion does not support force output with fp32"); - } + + PADDLE_ENFORCE(!fuse_residual_conn || !force_fp32_output, + "residual fusion does not support force output with fp32"); + bool is_conv3d = strides.size() == 3U; // TODO(tpatejko): add support for dilation PADDLE_ENFORCE( @@ -356,23 +389,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { mkldnn::memory::data_type src_dt = paddle::framework::ToMKLDNNDataType(input->type()); - auto dst_dt = unsigned_output - ? paddle::framework::ToMKLDNNDataType( - framework::DataTypeTrait::DataType()) - : paddle::framework::ToMKLDNNDataType( - framework::DataTypeTrait::DataType()); - - if (force_fp32_output) { - dst_dt = paddle::framework::ToMKLDNNDataType( - framework::DataTypeTrait::DataType()); - } - - if (fuse_residual_conn) { - auto residual = ctx.Input("ResidualData"); - auto residual_dt = paddle::framework::ToMKLDNNDataType(residual->type()); - if (dst_dt != residual_dt) dst_dt = residual_dt; - } - // Get unique name for storing MKLDNN primitives std::string key; key.reserve(MaxKeyLength); @@ -453,28 +469,35 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { platform::MKLDNNMemDesc(src_tz, src_dt, chosen_memory_format); auto weights_md = platform::MKLDNNMemDesc( weights_tz, memory::data_type::s8, chosen_memory_format); - auto dst_md = - platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format); + auto dst_md = platform::MKLDNNMemDesc( + dst_tz, platform::MKLDNNGetDataType(), chosen_memory_format); + handler.reset( + new platform::ConvMKLDNNHandler(dev_ctx, mkldnn_engine, key)); // create a conv primitive descriptor and save it for usage in backward // TODO(lidanqing): We use relu post-op instead of brelu post-op cause // mkldnn v0.18 does not support INT8 brelu post-op. Use code in /**/ when // v0.20 is enabled - std::shared_ptr bias_md_p; + auto propagation = is_test ? mkldnn::prop_kind::forward_scoring + : mkldnn::prop_kind::forward_training; + if (bias) { bias_tz = paddle::framework::vectorize2int(bias->dims()); - bias_md_p = std::make_shared(platform::MKLDNNMemDesc( - bias_tz, memory::data_type::s32, memory::format::x)); + auto bias_md = platform::MKLDNNMemDesc(bias_tz, memory::data_type::s32, + mkldnn::memory::format::x); + conv_pd = handler->AcquireConvolutionPrimitiveDescriptor( + src_md, weights_md, bias_md, dst_md, strides, paddings, + mkldnn_engine, fuse_relu || fuse_brelu /*fuse_relu*/, + fuse_residual_conn, false /*fuse_brelu*/, fuse_brelu_threshold, + propagation, output_shift_scale, sum_scale); + } else { + conv_pd = handler->AcquireConvolutionPrimitiveDescriptor( + src_md, weights_md, boost::none, dst_md, strides, paddings, + mkldnn_engine, fuse_relu || fuse_brelu /*fuse_relu*/, + fuse_residual_conn, false /*fuse_brelu*/, fuse_brelu_threshold, + propagation, output_shift_scale, sum_scale); } - conv_pd = ConvFwdPrimitiveDesc( - src_md, weights_md, bias_md_p, dst_md, strides, paddings, - mkldnn_engine, fuse_relu || fuse_brelu /*fuse_relu*/, - fuse_residual_conn, false /*fuse_brelu*/, fuse_brelu_threshold, - output_shift_scale, sum_scale, is_test); - // Save conv_pd/src_memory/weights_memory for backward pass - dev_ctx.SetBlob(key_conv_pd, conv_pd); - handler.reset(new platform::ConvMKLDNNHandler(conv_pd, dev_ctx, - mkldnn_engine, key)); + // create mkldnn memory from input tensors (data/weights) user_src_memory_p = handler->AcquireSrcMemory(user_src_md, to_void_cast(input_data)); @@ -502,38 +525,20 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { if (residual_param->format() != handler->GetDstFormat()) { auto residual_data_tz = paddle::framework::vectorize2int(residual_param->dims()); - auto user_residual_md = platform::MKLDNNMemDesc( residual_data_tz, residual_dt, residual_param->format()); - - if (residual_dt == mkldnn::memory::data_type::u8) { - dst_memory_p = platform::SetDstMemory( - ctx, output, residual_param, user_residual_md, handler, - &pipeline); - } else { - need_s8_to_u8 = unsigned_output; - dst_memory_p = platform::SetDstMemory( - ctx, output, residual_param, user_residual_md, handler, - &pipeline); - } + dst_memory_p = platform::SetDstMemory( + ctx, output, residual_param, user_residual_md, handler, + &pipeline); } else { output->ShareDataWith(*residual_param); - if (residual_dt == mkldnn::memory::data_type::u8) { - dst_memory_p = - platform::SetDstMemory(ctx, output, handler); - } else { - need_s8_to_u8 = unsigned_output; - dst_memory_p = platform::SetDstMemory(ctx, output, handler); - } - } - } else if (!force_fp32_output) { - if (unsigned_output) { - dst_memory_p = platform::SetDstMemory(ctx, output, handler); - } else { - dst_memory_p = platform::SetDstMemory(ctx, output, handler); + dst_memory_p = platform::SetDstMemory(ctx, output, handler); } + need_s8_to_u8 = + (platform::MKLDNNGetDataType() == memory::data_type::s8) && + unsigned_output; } else { - dst_memory_p = platform::SetDstMemory(ctx, output, handler); + dst_memory_p = platform::SetDstMemory(ctx, output, handler); } // create convolution op primitive @@ -564,7 +569,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { conv_p = handler->AcquireConvolution(src_memory_p, weights_memory_p, dst_memory_p); } - // push primitive to stream and wait until it's executed pipeline.push_back(*conv_p); } else { @@ -592,29 +596,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { if (fuse_residual_conn) { auto residual_param = ctx.Input("ResidualData"); - auto residual_dt = - paddle::framework::ToMKLDNNDataType(residual_param->type()); output->ShareDataWith(*residual_param); - if (residual_dt == mkldnn::memory::data_type::u8) { - platform::SetDstMemoryHandler(ctx, output, handler, - &dst_memory_p); - } else { - need_s8_to_u8 = unsigned_output; - platform::SetDstMemoryHandler(ctx, output, handler, - &dst_memory_p); - } - } else if (!force_fp32_output) { - if (unsigned_output) { - platform::SetDstMemoryHandler(ctx, output, handler, - &dst_memory_p); - } else { - platform::SetDstMemoryHandler(ctx, output, handler, - &dst_memory_p); - } - } else { - platform::SetDstMemoryHandler(ctx, output, handler, - &dst_memory_p); + need_s8_to_u8 = + (platform::MKLDNNGetDataType() == memory::data_type::s8) && + unsigned_output; } + platform::SetDstMemoryHandler(ctx, output, handler, dst_memory_p); if (src_memory_reorder_p) { pipeline.push_back(*src_memory_reorder_p); @@ -625,87 +612,16 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { if (residual_reorder_p) { pipeline.push_back(*residual_reorder_p); } - pipeline.push_back(*conv_p); } // push primitive to stream and wait until it's executed stream(stream::kind::eager).submit(pipeline).wait(); - if (need_s8_to_u8) { output->mutable_data(ctx.GetPlace()); } - output->set_layout(DataLayout::kMKLDNN); output->set_format(GetMKLDNNFormat(*dst_memory_p)); } - - private: - mkldnn::primitive_attr CreatePostOps( - bool fuse_relu, bool fuse_residual_conn, - const std::vector& output_shift_scale, float sum_scale, - bool fuse_brelu, float fuse_brelu_threshold) const { - mkldnn::primitive_attr conv_attr; - mkldnn::post_ops post_operations; - int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0; - conv_attr.set_output_scales(mask, output_shift_scale); - - if (fuse_residual_conn) { - post_operations.append_sum(sum_scale); - } - if (fuse_relu) { - constexpr float scale = 1.0f; - constexpr float negative_slope = 0.0f; - constexpr float placeholder = 1.0f; // beta - post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu, - negative_slope, placeholder); - } - if (fuse_brelu) { - constexpr float scale = 1.0f; - constexpr float placeholder = 0.0f; // beta - post_operations.append_eltwise(scale, - mkldnn::algorithm::eltwise_bounded_relu, - fuse_brelu_threshold, placeholder); - } - conv_attr.set_post_ops(post_operations); - return conv_attr; - } - - std::unique_ptr - ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights, - const std::shared_ptr bias_md_p, - const memory::desc& dst, const std::vector& strides, - const std::vector& paddings, - const mkldnn::engine& engine, const bool fuse_relu, - const bool fuse_residual_conn, const bool fuse_brelu, - const float fuse_brelu_threshold, - const std::vector& output_shift_scale, - const float sum_scale, bool is_test) const { - memory::dims stride_dims = {strides[0], strides[1]}; - memory::dims padding_dims = {paddings[0], paddings[1]}; - - auto propagation = is_test ? mkldnn::prop_kind::forward_scoring - : mkldnn::prop_kind::forward_training; - auto conv_desc = - (bias_md_p != nullptr) - ? mkldnn::convolution_forward::desc( - propagation, mkldnn::convolution_direct, src, weights, - (*bias_md_p), dst, stride_dims, padding_dims, padding_dims, - mkldnn::padding_kind::zero) - : mkldnn::convolution_forward::desc( - propagation, mkldnn::convolution_direct, src, weights, dst, - stride_dims, padding_dims, padding_dims, - mkldnn::padding_kind::zero); - - mkldnn::primitive_attr conv_attr = - CreatePostOps(fuse_relu, fuse_residual_conn, output_shift_scale, - sum_scale, fuse_brelu, fuse_brelu_threshold); - - auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc( - conv_desc, conv_attr, engine); - - return std::unique_ptr( - p_conv_pd); - } }; template diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index dafdb4eab9..8bcb8acee9 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -20,7 +20,6 @@ limitations under the License. */ #include #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/platform/place.h" - namespace paddle { namespace platform { @@ -82,22 +81,24 @@ inline bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) { template mkldnn::memory::data_type MKLDNNGetDataType() { - return mkldnn::memory::data_undef; + return mkldnn::memory::data_type::data_undef; } template <> inline mkldnn::memory::data_type MKLDNNGetDataType() { - return mkldnn::memory::f32; + return mkldnn::memory::data_type::f32; +} +template <> +inline mkldnn::memory::data_type MKLDNNGetDataType() { + return mkldnn::memory::data_type::s32; } - template <> inline mkldnn::memory::data_type MKLDNNGetDataType() { - return mkldnn::memory::s8; + return mkldnn::memory::data_type::s8; } - template <> inline mkldnn::memory::data_type MKLDNNGetDataType() { - return mkldnn::memory::u8; + return mkldnn::memory::data_type::u8; } inline void Reorder(const mkldnn::memory& src, const mkldnn::memory& dst) { diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 9f277d682b..eb25a4e046 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -1160,18 +1160,24 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { scale_data, mask); } - mkldnn::primitive_attr CreatePostOps(bool fuse_relu, bool fuse_residual_conn, - bool fuse_brelu, - float fuse_brelu_threshold) const { + mkldnn::primitive_attr CreatePostOps( + bool fuse_relu, bool fuse_residual_conn, bool fuse_brelu, + float fuse_brelu_threshold, + const std::vector output_shift_scale = {}, + float sum_scale = 1.0f) const { mkldnn::primitive_attr conv_attr; mkldnn::post_ops post_operations; + if (output_shift_scale.size() > 0) { + int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0; + conv_attr.set_output_scales(mask, output_shift_scale); + } // Fusion with Elementwise layer relies on adding a sum post-operation with // the scale parameter. It is assumed that when fuse_residual_connection is // true, the output tensor contains the data coming from residual // connection. The result of this post_op is: // Output = scale * Output + Conv_Out. if (fuse_residual_conn) { - post_operations.append_sum(1.0f); + post_operations.append_sum(sum_scale); } // Fusion with ReLU layer is executed through the PostOps feature. Create a // PostOps object and configure it to execute an eltwise relu operation. @@ -1202,7 +1208,9 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { const std::vector& paddings, const mkldnn::engine& engine, const bool fuse_relu, const bool fuse_residual_conn, const bool fuse_brelu, const float fuse_brelu_threshold, - mkldnn::prop_kind fwd_prop_kind) { + mkldnn::prop_kind fwd_prop_kind, + const std::vector output_shift_scale = {}, + const float sum_scale = 1.0f) { // Conv PD has to be passed to Grad op that // may be exxecuted by diffrent thread, hence // for that one we use key that does not contain TID @@ -1232,8 +1240,9 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { src, weights, dst, stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero); - mkldnn::primitive_attr conv_attr = CreatePostOps( - fuse_relu, fuse_residual_conn, fuse_brelu, fuse_brelu_threshold); + mkldnn::primitive_attr conv_attr = + CreatePostOps(fuse_relu, fuse_residual_conn, fuse_brelu, + fuse_brelu_threshold, output_shift_scale, sum_scale); conv_pd_.reset(new typename forward_t::primitive_desc( conv_desc, conv_attr, engine)); @@ -1393,10 +1402,10 @@ template static void SetDstMemoryHandler( const framework::ExecutionContext& ctx, framework::Tensor* output, const std::shared_ptr& handler, - std::shared_ptr* dst_memory_p) { + std::shared_ptr dst_memory_p) { T* output_data = output->mutable_data(ctx.GetPlace(), handler->GetDstMemorySize()); - (*dst_memory_p)->set_data_handle(to_void_cast(output_data)); + dst_memory_p->set_data_handle(to_void_cast(output_data)); } template -- GitLab