提交 04b6c29e 编写于 作者: L lidanqing 提交者: Tao Luo

Improve mobilenetv2 INT8 performance by using INT8 relu as post-op (#17570)

* add INT8 conv+relu6 fuse and enbale mobilentv2 INT8 test
test=develop

* change fasle and 0.0 to fuse_brelu and brelu_threshold
test=develop

change the "fuse_relu||fuse_brelu" to "unsigned_output"
test=develop

* Use relu instead of brelu as INT8 post-op because INT8 brelu is not enabled in mkldnn v0.18
test=develop

* continuous-integration fix
test=develop
上级 6d8075ec
......@@ -64,8 +64,10 @@ bool AnalysisPredictor::MkldnnQuantizer::CalculateScales() {
bool is_unsigned = false;
if (is_output && op->Type() == "conv2d") {
// output of conv2d with relu must be unsigned
is_unsigned = op->HasAttr("fuse_relu") &&
boost::get<bool>(op->GetAttr("fuse_relu"));
is_unsigned = (op->HasAttr("fuse_relu") &&
boost::get<bool>(op->GetAttr("fuse_relu"))) ||
(op->HasAttr("fuse_brelu") &&
boost::get<bool>(op->GetAttr("fuse_brelu")));
} else if (is_output && op->Type() == "relu") {
is_unsigned = true;
} else if (is_output &&
......
......@@ -176,6 +176,13 @@ if(WITH_MKLDNN)
endif()
inference_analysis_api_int8_test(test_analyzer_int8_mobilenet ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc)
#mobilenetv2 int8
set(INT8_MOBILENETV2_MODEL_DIR "${INT8_DATA_DIR}/mobilenetv2")
if (NOT EXISTS ${INT8_MOBILENETV2_MODEL_DIR})
inference_download_and_uncompress(${INT8_MOBILENETV2_MODEL_DIR} "${INFERENCE_URL}/int8" "mobilenet_v2_int8_model.tar.gz" )
endif()
inference_analysis_api_int8_test(test_analyzer_int8_mobilenetv2 ${INT8_MOBILENETV2_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc)
#resnet101 int8
set(INT8_RESNET101_MODEL_DIR "${INT8_DATA_DIR}/resnet101")
if (NOT EXISTS ${INT8_RESNET101_MODEL_DIR})
......
......@@ -288,6 +288,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(*dst_memory_p));
}
void ComputeINT8(const paddle::framework::ExecutionContext& ctx) const {
const bool is_test = ctx.Attr<bool>("is_test");
......@@ -325,7 +326,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bool fuse_relu = ctx.Attr<bool>("fuse_relu");
bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
bool fuse_brelu = ctx.Attr<bool>("fuse_brelu");
float fuse_brelu_threshold = ctx.Attr<float>("fuse_brelu_threshold");
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
bool unsigned_output = fuse_relu || fuse_brelu;
if (fuse_residual_conn) {
PADDLE_ENFORCE(force_fp32_output != true,
"residual fusion does not support force output with fp32");
......@@ -340,8 +343,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
"dilation in convolution is not implemented yet");
PADDLE_ENFORCE(is_conv3d != true, "int8 does not support conv3d currently");
PADDLE_ENFORCE(fuse_brelu != true,
"int8 does not support conv/relu6 fusion currently");
const T* input_data = input->data<T>();
......@@ -356,10 +357,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
mkldnn::memory::data_type src_dt =
paddle::framework::ToMKLDNNDataType(input->type());
auto dst_dt = (fuse_relu) ? paddle::framework::ToMKLDNNDataType(
framework::DataTypeTrait<uint8_t>::DataType)
: paddle::framework::ToMKLDNNDataType(
framework::DataTypeTrait<int8_t>::DataType);
auto dst_dt = unsigned_output
? paddle::framework::ToMKLDNNDataType(
framework::DataTypeTrait<uint8_t>::DataType)
: paddle::framework::ToMKLDNNDataType(
framework::DataTypeTrait<int8_t>::DataType);
if (force_fp32_output) {
dst_dt = paddle::framework::ToMKLDNNDataType(
......@@ -377,13 +379,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
key.reserve(MaxKeyLength);
platform::ConvMKLDNNHandler::AppendKey(
&key, src_tz, weights_tz, strides, paddings, dilations, groups, src_dt,
input->format(), fuse_relu, fuse_residual_conn, false /*fuse_brelu*/,
input->format(), fuse_relu, fuse_residual_conn, fuse_brelu,
ctx.op().Input("Input") + ctx.op().Input("Filter"));
const std::string key_conv_pd = key + "@conv_pd";
bool need_s8_to_u8 = false;
std::shared_ptr<mkldnn::convolution_forward> conv_p = nullptr;
std::shared_ptr<mkldnn::memory> src_memory_p = nullptr;
std::shared_ptr<mkldnn::memory> user_src_memory_p = nullptr;
......@@ -456,6 +457,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format);
// create a conv primitive descriptor and save it for usage in backward
// TODO(lidanqing): We use relu post-op instead of brelu post-op cause
// mkldnn v0.18 does not support INT8 brelu post-op. Use code in /**/ when
// v0.20 is enabled
if (bias) {
bias_tz = paddle::framework::vectorize2int(bias->dims());
auto bias_md = platform::MKLDNNMemDesc(bias_tz, memory::data_type::s32,
......@@ -463,16 +467,16 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
conv_pd = ConvFwdPrimitiveDesc(
src_md, weights_md, bias_md, dst_md, strides, paddings,
mkldnn_engine, fuse_relu, fuse_residual_conn, false /*fuse_brelu*/,
0.0 /*fuse_brelu_threshold*/, output_shift_scale, sum_scale,
is_test);
mkldnn_engine, fuse_relu || fuse_brelu /*fuse_relu*/,
fuse_residual_conn, false /*fuse_brelu*/, fuse_brelu_threshold,
output_shift_scale, sum_scale, is_test);
} else {
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides,
paddings, mkldnn_engine, fuse_relu,
fuse_residual_conn, false /*fuse_brelu*/,
0.0 /*fuse_brelu_threshold*/,
output_shift_scale, sum_scale, is_test);
conv_pd = ConvFwdPrimitiveDesc(
src_md, weights_md, dst_md, strides, paddings, mkldnn_engine,
fuse_relu || fuse_brelu /*fuse_relu*/, fuse_residual_conn,
false /*fuse_brelu*/, fuse_brelu_threshold, output_shift_scale,
sum_scale, is_test);
}
// Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx.SetBlob(key_conv_pd, conv_pd);
......@@ -514,7 +518,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
ctx, output, residual_param, user_residual_md, handler,
&pipeline);
} else {
need_s8_to_u8 = fuse_relu;
need_s8_to_u8 = unsigned_output;
dst_memory_p = platform::SetDstMemory<int8_t>(
ctx, output, residual_param, user_residual_md, handler,
&pipeline);
......@@ -525,12 +529,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dst_memory_p =
platform::SetDstMemory<uint8_t>(ctx, output, handler);
} else {
need_s8_to_u8 = fuse_relu;
need_s8_to_u8 = unsigned_output;
dst_memory_p = platform::SetDstMemory<int8_t>(ctx, output, handler);
}
}
} else if (!force_fp32_output) {
if (fuse_relu) {
if (unsigned_output) {
dst_memory_p = platform::SetDstMemory<uint8_t>(ctx, output, handler);
} else {
dst_memory_p = platform::SetDstMemory<int8_t>(ctx, output, handler);
......@@ -602,12 +606,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
platform::SetDstMemoryHandler<uint8_t>(ctx, output, handler,
&dst_memory_p);
} else {
need_s8_to_u8 = fuse_relu;
need_s8_to_u8 = unsigned_output;
platform::SetDstMemoryHandler<int8_t>(ctx, output, handler,
&dst_memory_p);
}
} else if (!force_fp32_output) {
if (fuse_relu) {
if (unsigned_output) {
platform::SetDstMemoryHandler<uint8_t>(ctx, output, handler,
&dst_memory_p);
} else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册