From 04b6c29ee01126362c193770ff7d853bbbb26754 Mon Sep 17 00:00:00 2001
From: lidanqing
Date: Tue, 28 May 2019 12:07:43 +0200
Subject: [PATCH] Improve mobilenetv2 INT8 performance by using INT8 relu as
 post-op (#17570)

* add INT8 conv+relu6 fuse and enable mobilenetv2 INT8 test
test=develop

* change false and 0.0 to fuse_brelu and brelu_threshold
test=develop

change the "fuse_relu||fuse_brelu" to "unsigned_output"
test=develop

* Use relu instead of brelu as INT8 post-op because INT8 brelu is not
enabled in mkldnn v0.18
test=develop

* continuous-integration fix
test=develop
---
 .../fluid/inference/api/mkldnn_quantizer.cc   |  6 ++-
 .../fluid/inference/tests/api/CMakeLists.txt  |  7 +++
 .../fluid/operators/mkldnn/conv_mkldnn_op.cc  | 46 ++++++++++---------
 3 files changed, 36 insertions(+), 23 deletions(-)

diff --git a/paddle/fluid/inference/api/mkldnn_quantizer.cc b/paddle/fluid/inference/api/mkldnn_quantizer.cc
index df9678d693a..9d560ddd2e0 100644
--- a/paddle/fluid/inference/api/mkldnn_quantizer.cc
+++ b/paddle/fluid/inference/api/mkldnn_quantizer.cc
@@ -64,8 +64,10 @@ bool AnalysisPredictor::MkldnnQuantizer::CalculateScales() {
     bool is_unsigned = false;
     if (is_output && op->Type() == "conv2d") {
       // output of conv2d with relu must be unsigned
-      is_unsigned = op->HasAttr("fuse_relu") &&
-                    boost::get<bool>(op->GetAttr("fuse_relu"));
+      is_unsigned = (op->HasAttr("fuse_relu") &&
+                     boost::get<bool>(op->GetAttr("fuse_relu"))) ||
+                    (op->HasAttr("fuse_brelu") &&
+                     boost::get<bool>(op->GetAttr("fuse_brelu")));
     } else if (is_output && op->Type() == "relu") {
       is_unsigned = true;
     } else if (is_output &&
diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt
index b37e3936d1b..f96c920d285 100644
--- a/paddle/fluid/inference/tests/api/CMakeLists.txt
+++ b/paddle/fluid/inference/tests/api/CMakeLists.txt
@@ -176,6 +176,13 @@ if(WITH_MKLDNN)
   endif()
   inference_analysis_api_int8_test(test_analyzer_int8_mobilenet ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc)
 
+  #mobilenetv2 int8
+  set(INT8_MOBILENETV2_MODEL_DIR "${INT8_DATA_DIR}/mobilenetv2")
+  if (NOT EXISTS ${INT8_MOBILENETV2_MODEL_DIR})
+    inference_download_and_uncompress(${INT8_MOBILENETV2_MODEL_DIR} "${INFERENCE_URL}/int8" "mobilenet_v2_int8_model.tar.gz")
+  endif()
+  inference_analysis_api_int8_test(test_analyzer_int8_mobilenetv2 ${INT8_MOBILENETV2_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc)
+
   #resnet101 int8
   set(INT8_RESNET101_MODEL_DIR "${INT8_DATA_DIR}/resnet101")
   if (NOT EXISTS ${INT8_RESNET101_MODEL_DIR})
diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
index 28db85c3ec0..01540e0ef28 100644
--- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
@@ -288,6 +288,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     output->set_layout(DataLayout::kMKLDNN);
     output->set_format(GetMKLDNNFormat(*dst_memory_p));
   }
+
   void ComputeINT8(const paddle::framework::ExecutionContext& ctx) const {
     const bool is_test = ctx.Attr<bool>("is_test");
 
@@ -325,7 +326,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     bool fuse_relu = ctx.Attr<bool>("fuse_relu");
     bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
     bool fuse_brelu = ctx.Attr<bool>("fuse_brelu");
+    float fuse_brelu_threshold = ctx.Attr<float>("fuse_brelu_threshold");
     bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
+    bool unsigned_output = fuse_relu || fuse_brelu;
     if (fuse_residual_conn) {
       PADDLE_ENFORCE(force_fp32_output != true,
                      "residual fusion does not support force output with fp32");
@@ -340,8 +343,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                    "dilation in convolution is not implemented yet");
     PADDLE_ENFORCE(is_conv3d != true,
                    "int8 does not support conv3d currently");
-    PADDLE_ENFORCE(fuse_brelu != true,
-                   "int8 does not support conv/relu6 fusion currently");
 
     const T* input_data = input->data<T>();
 
@@ -356,10 +357,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     mkldnn::memory::data_type src_dt =
         paddle::framework::ToMKLDNNDataType(input->type());
 
-    auto dst_dt = (fuse_relu) ? paddle::framework::ToMKLDNNDataType(
-                                    framework::DataTypeTrait<uint8_t>::DataType)
-                              : paddle::framework::ToMKLDNNDataType(
-                                    framework::DataTypeTrait<int8_t>::DataType);
+    auto dst_dt = unsigned_output
+                      ? paddle::framework::ToMKLDNNDataType(
+                            framework::DataTypeTrait<uint8_t>::DataType)
+                      : paddle::framework::ToMKLDNNDataType(
+                            framework::DataTypeTrait<int8_t>::DataType);
 
     if (force_fp32_output) {
       dst_dt = paddle::framework::ToMKLDNNDataType(
@@ -377,13 +379,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     key.reserve(MaxKeyLength);
     platform::ConvMKLDNNHandler::AppendKey(
         &key, src_tz, weights_tz, strides, paddings, dilations, groups, src_dt,
-        input->format(), fuse_relu, fuse_residual_conn, false /*fuse_brelu*/,
+        input->format(), fuse_relu, fuse_residual_conn, fuse_brelu,
         ctx.op().Input("Input") + ctx.op().Input("Filter"));
     const std::string key_conv_pd = key + "@conv_pd";
 
     bool need_s8_to_u8 = false;
-
     std::shared_ptr<mkldnn::convolution_forward> conv_p = nullptr;
     std::shared_ptr<mkldnn::memory> src_memory_p = nullptr;
     std::shared_ptr<mkldnn::memory> user_src_memory_p = nullptr;
@@ -456,6 +457,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
           platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format);
 
       // create a conv primitive descriptor and save it for usage in backward
+      // TODO(lidanqing): We use relu post-op instead of brelu post-op because
+      // mkldnn v0.18 does not support INT8 brelu post-op. Use the code in /**/
Use code in /**/ when + // v0.20 is enabled if (bias) { bias_tz = paddle::framework::vectorize2int(bias->dims()); auto bias_md = platform::MKLDNNMemDesc(bias_tz, memory::data_type::s32, @@ -463,16 +467,16 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { conv_pd = ConvFwdPrimitiveDesc( src_md, weights_md, bias_md, dst_md, strides, paddings, - mkldnn_engine, fuse_relu, fuse_residual_conn, false /*fuse_brelu*/, - 0.0 /*fuse_brelu_threshold*/, output_shift_scale, sum_scale, - is_test); + mkldnn_engine, fuse_relu || fuse_brelu /*fuse_relu*/, + fuse_residual_conn, false /*fuse_brelu*/, fuse_brelu_threshold, + output_shift_scale, sum_scale, is_test); } else { - conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, - paddings, mkldnn_engine, fuse_relu, - fuse_residual_conn, false /*fuse_brelu*/, - 0.0 /*fuse_brelu_threshold*/, - output_shift_scale, sum_scale, is_test); + conv_pd = ConvFwdPrimitiveDesc( + src_md, weights_md, dst_md, strides, paddings, mkldnn_engine, + fuse_relu || fuse_brelu /*fuse_relu*/, fuse_residual_conn, + false /*fuse_brelu*/, fuse_brelu_threshold, output_shift_scale, + sum_scale, is_test); } // Save conv_pd/src_memory/weights_memory for backward pass dev_ctx.SetBlob(key_conv_pd, conv_pd); @@ -514,7 +518,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { ctx, output, residual_param, user_residual_md, handler, &pipeline); } else { - need_s8_to_u8 = fuse_relu; + need_s8_to_u8 = unsigned_output; dst_memory_p = platform::SetDstMemory( ctx, output, residual_param, user_residual_md, handler, &pipeline); @@ -525,12 +529,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { dst_memory_p = platform::SetDstMemory(ctx, output, handler); } else { - need_s8_to_u8 = fuse_relu; + need_s8_to_u8 = unsigned_output; dst_memory_p = platform::SetDstMemory(ctx, output, handler); } } } else if (!force_fp32_output) { - if (fuse_relu) { + if (unsigned_output) { dst_memory_p = platform::SetDstMemory(ctx, output, handler); } else { dst_memory_p = platform::SetDstMemory(ctx, output, handler); @@ -602,12 +606,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { platform::SetDstMemoryHandler(ctx, output, handler, &dst_memory_p); } else { - need_s8_to_u8 = fuse_relu; + need_s8_to_u8 = unsigned_output; platform::SetDstMemoryHandler(ctx, output, handler, &dst_memory_p); } } else if (!force_fp32_output) { - if (fuse_relu) { + if (unsigned_output) { platform::SetDstMemoryHandler(ctx, output, handler, &dst_memory_p); } else { -- GitLab