提交 04b6c29e 编写于 作者: L lidanqing 提交者: Tao Luo

Improve mobilenetv2 INT8 performance by using INT8 relu as post-op (#17570)

* add INT8 conv+relu6 fuse and enbale mobilentv2 INT8 test
test=develop

* change fasle and 0.0 to fuse_brelu and brelu_threshold
test=develop

change the "fuse_relu||fuse_brelu" to "unsigned_output"
test=develop

* Use relu instead of brelu as INT8 post-op because INT8 brelu is not enabled in mkldnn v0.18
test=develop

* continuous-integration fix
test=develop
上级 6d8075ec
...@@ -64,8 +64,10 @@ bool AnalysisPredictor::MkldnnQuantizer::CalculateScales() { ...@@ -64,8 +64,10 @@ bool AnalysisPredictor::MkldnnQuantizer::CalculateScales() {
bool is_unsigned = false; bool is_unsigned = false;
if (is_output && op->Type() == "conv2d") { if (is_output && op->Type() == "conv2d") {
// output of conv2d with relu must be unsigned // output of conv2d with relu must be unsigned
is_unsigned = op->HasAttr("fuse_relu") && is_unsigned = (op->HasAttr("fuse_relu") &&
boost::get<bool>(op->GetAttr("fuse_relu")); boost::get<bool>(op->GetAttr("fuse_relu"))) ||
(op->HasAttr("fuse_brelu") &&
boost::get<bool>(op->GetAttr("fuse_brelu")));
} else if (is_output && op->Type() == "relu") { } else if (is_output && op->Type() == "relu") {
is_unsigned = true; is_unsigned = true;
} else if (is_output && } else if (is_output &&
......
...@@ -176,6 +176,13 @@ if(WITH_MKLDNN) ...@@ -176,6 +176,13 @@ if(WITH_MKLDNN)
endif() endif()
inference_analysis_api_int8_test(test_analyzer_int8_mobilenet ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc) inference_analysis_api_int8_test(test_analyzer_int8_mobilenet ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc)
#mobilenetv2 int8
set(INT8_MOBILENETV2_MODEL_DIR "${INT8_DATA_DIR}/mobilenetv2")
if (NOT EXISTS ${INT8_MOBILENETV2_MODEL_DIR})
inference_download_and_uncompress(${INT8_MOBILENETV2_MODEL_DIR} "${INFERENCE_URL}/int8" "mobilenet_v2_int8_model.tar.gz" )
endif()
inference_analysis_api_int8_test(test_analyzer_int8_mobilenetv2 ${INT8_MOBILENETV2_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc)
#resnet101 int8 #resnet101 int8
set(INT8_RESNET101_MODEL_DIR "${INT8_DATA_DIR}/resnet101") set(INT8_RESNET101_MODEL_DIR "${INT8_DATA_DIR}/resnet101")
if (NOT EXISTS ${INT8_RESNET101_MODEL_DIR}) if (NOT EXISTS ${INT8_RESNET101_MODEL_DIR})
......
...@@ -288,6 +288,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -288,6 +288,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
output->set_layout(DataLayout::kMKLDNN); output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(*dst_memory_p)); output->set_format(GetMKLDNNFormat(*dst_memory_p));
} }
void ComputeINT8(const paddle::framework::ExecutionContext& ctx) const { void ComputeINT8(const paddle::framework::ExecutionContext& ctx) const {
const bool is_test = ctx.Attr<bool>("is_test"); const bool is_test = ctx.Attr<bool>("is_test");
...@@ -325,7 +326,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -325,7 +326,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bool fuse_relu = ctx.Attr<bool>("fuse_relu"); bool fuse_relu = ctx.Attr<bool>("fuse_relu");
bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection"); bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
bool fuse_brelu = ctx.Attr<bool>("fuse_brelu"); bool fuse_brelu = ctx.Attr<bool>("fuse_brelu");
float fuse_brelu_threshold = ctx.Attr<float>("fuse_brelu_threshold");
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output"); bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
bool unsigned_output = fuse_relu || fuse_brelu;
if (fuse_residual_conn) { if (fuse_residual_conn) {
PADDLE_ENFORCE(force_fp32_output != true, PADDLE_ENFORCE(force_fp32_output != true,
"residual fusion does not support force output with fp32"); "residual fusion does not support force output with fp32");
...@@ -340,8 +343,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -340,8 +343,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
"dilation in convolution is not implemented yet"); "dilation in convolution is not implemented yet");
PADDLE_ENFORCE(is_conv3d != true, "int8 does not support conv3d currently"); PADDLE_ENFORCE(is_conv3d != true, "int8 does not support conv3d currently");
PADDLE_ENFORCE(fuse_brelu != true,
"int8 does not support conv/relu6 fusion currently");
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
...@@ -356,10 +357,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -356,10 +357,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
mkldnn::memory::data_type src_dt = mkldnn::memory::data_type src_dt =
paddle::framework::ToMKLDNNDataType(input->type()); paddle::framework::ToMKLDNNDataType(input->type());
auto dst_dt = (fuse_relu) ? paddle::framework::ToMKLDNNDataType( auto dst_dt = unsigned_output
framework::DataTypeTrait<uint8_t>::DataType) ? paddle::framework::ToMKLDNNDataType(
: paddle::framework::ToMKLDNNDataType( framework::DataTypeTrait<uint8_t>::DataType)
framework::DataTypeTrait<int8_t>::DataType); : paddle::framework::ToMKLDNNDataType(
framework::DataTypeTrait<int8_t>::DataType);
if (force_fp32_output) { if (force_fp32_output) {
dst_dt = paddle::framework::ToMKLDNNDataType( dst_dt = paddle::framework::ToMKLDNNDataType(
...@@ -377,13 +379,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -377,13 +379,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
key.reserve(MaxKeyLength); key.reserve(MaxKeyLength);
platform::ConvMKLDNNHandler::AppendKey( platform::ConvMKLDNNHandler::AppendKey(
&key, src_tz, weights_tz, strides, paddings, dilations, groups, src_dt, &key, src_tz, weights_tz, strides, paddings, dilations, groups, src_dt,
input->format(), fuse_relu, fuse_residual_conn, false /*fuse_brelu*/, input->format(), fuse_relu, fuse_residual_conn, fuse_brelu,
ctx.op().Input("Input") + ctx.op().Input("Filter")); ctx.op().Input("Input") + ctx.op().Input("Filter"));
const std::string key_conv_pd = key + "@conv_pd"; const std::string key_conv_pd = key + "@conv_pd";
bool need_s8_to_u8 = false; bool need_s8_to_u8 = false;
std::shared_ptr<mkldnn::convolution_forward> conv_p = nullptr; std::shared_ptr<mkldnn::convolution_forward> conv_p = nullptr;
std::shared_ptr<mkldnn::memory> src_memory_p = nullptr; std::shared_ptr<mkldnn::memory> src_memory_p = nullptr;
std::shared_ptr<mkldnn::memory> user_src_memory_p = nullptr; std::shared_ptr<mkldnn::memory> user_src_memory_p = nullptr;
...@@ -456,6 +457,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -456,6 +457,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format); platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format);
// create a conv primitive descriptor and save it for usage in backward // create a conv primitive descriptor and save it for usage in backward
// TODO(lidanqing): We use relu post-op instead of brelu post-op cause
// mkldnn v0.18 does not support INT8 brelu post-op. Use code in /**/ when
// v0.20 is enabled
if (bias) { if (bias) {
bias_tz = paddle::framework::vectorize2int(bias->dims()); bias_tz = paddle::framework::vectorize2int(bias->dims());
auto bias_md = platform::MKLDNNMemDesc(bias_tz, memory::data_type::s32, auto bias_md = platform::MKLDNNMemDesc(bias_tz, memory::data_type::s32,
...@@ -463,16 +467,16 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -463,16 +467,16 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
conv_pd = ConvFwdPrimitiveDesc( conv_pd = ConvFwdPrimitiveDesc(
src_md, weights_md, bias_md, dst_md, strides, paddings, src_md, weights_md, bias_md, dst_md, strides, paddings,
mkldnn_engine, fuse_relu, fuse_residual_conn, false /*fuse_brelu*/, mkldnn_engine, fuse_relu || fuse_brelu /*fuse_relu*/,
0.0 /*fuse_brelu_threshold*/, output_shift_scale, sum_scale, fuse_residual_conn, false /*fuse_brelu*/, fuse_brelu_threshold,
is_test); output_shift_scale, sum_scale, is_test);
} else { } else {
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, conv_pd = ConvFwdPrimitiveDesc(
paddings, mkldnn_engine, fuse_relu, src_md, weights_md, dst_md, strides, paddings, mkldnn_engine,
fuse_residual_conn, false /*fuse_brelu*/, fuse_relu || fuse_brelu /*fuse_relu*/, fuse_residual_conn,
0.0 /*fuse_brelu_threshold*/, false /*fuse_brelu*/, fuse_brelu_threshold, output_shift_scale,
output_shift_scale, sum_scale, is_test); sum_scale, is_test);
} }
// Save conv_pd/src_memory/weights_memory for backward pass // Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx.SetBlob(key_conv_pd, conv_pd); dev_ctx.SetBlob(key_conv_pd, conv_pd);
...@@ -514,7 +518,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -514,7 +518,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
ctx, output, residual_param, user_residual_md, handler, ctx, output, residual_param, user_residual_md, handler,
&pipeline); &pipeline);
} else { } else {
need_s8_to_u8 = fuse_relu; need_s8_to_u8 = unsigned_output;
dst_memory_p = platform::SetDstMemory<int8_t>( dst_memory_p = platform::SetDstMemory<int8_t>(
ctx, output, residual_param, user_residual_md, handler, ctx, output, residual_param, user_residual_md, handler,
&pipeline); &pipeline);
...@@ -525,12 +529,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -525,12 +529,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dst_memory_p = dst_memory_p =
platform::SetDstMemory<uint8_t>(ctx, output, handler); platform::SetDstMemory<uint8_t>(ctx, output, handler);
} else { } else {
need_s8_to_u8 = fuse_relu; need_s8_to_u8 = unsigned_output;
dst_memory_p = platform::SetDstMemory<int8_t>(ctx, output, handler); dst_memory_p = platform::SetDstMemory<int8_t>(ctx, output, handler);
} }
} }
} else if (!force_fp32_output) { } else if (!force_fp32_output) {
if (fuse_relu) { if (unsigned_output) {
dst_memory_p = platform::SetDstMemory<uint8_t>(ctx, output, handler); dst_memory_p = platform::SetDstMemory<uint8_t>(ctx, output, handler);
} else { } else {
dst_memory_p = platform::SetDstMemory<int8_t>(ctx, output, handler); dst_memory_p = platform::SetDstMemory<int8_t>(ctx, output, handler);
...@@ -602,12 +606,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -602,12 +606,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
platform::SetDstMemoryHandler<uint8_t>(ctx, output, handler, platform::SetDstMemoryHandler<uint8_t>(ctx, output, handler,
&dst_memory_p); &dst_memory_p);
} else { } else {
need_s8_to_u8 = fuse_relu; need_s8_to_u8 = unsigned_output;
platform::SetDstMemoryHandler<int8_t>(ctx, output, handler, platform::SetDstMemoryHandler<int8_t>(ctx, output, handler,
&dst_memory_p); &dst_memory_p);
} }
} else if (!force_fp32_output) { } else if (!force_fp32_output) {
if (fuse_relu) { if (unsigned_output) {
platform::SetDstMemoryHandler<uint8_t>(ctx, output, handler, platform::SetDstMemoryHandler<uint8_t>(ctx, output, handler,
&dst_memory_p); &dst_memory_p);
} else { } else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册