Commit bce72c7f authored by joanna.wozna.intel, committed by Tao Luo

Replace Relu with bounded Relu in MobileNetV2 quantization (#18988)

test=develop
Parent e044e842
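Background: MobileNetV2 uses ReLU6, a bounded ReLU that clips activations to the range [0, 6]. Before this change the INT8 path substituted a plain ReLU post-op because mkldnn v0.18 had no INT8 bounded-ReLU post-op (see the TODO removed below); with mkldnn v0.20 the bounded variant can be fused directly. A minimal standalone sketch of the operation itself, for illustration only (not Paddle code):

    #include <algorithm>
    #include <cstdio>

    // Bounded ReLU (ReLU6 when threshold == 6.0f): clip the activation to
    // [0, threshold] instead of [0, +inf) as plain ReLU does.
    float bounded_relu(float x, float threshold) {
      return std::min(std::max(x, 0.0f), threshold);
    }

    int main() {
      // ReLU6 keeps values in [0, 6]; plain ReLU would let 7.5 pass unchanged.
      std::printf("%.1f %.1f %.1f\n",
                  bounded_relu(-1.0f, 6.0f),  // 0.0
                  bounded_relu(3.2f, 6.0f),   // 3.2
                  bounded_relu(7.5f, 6.0f));  // 6.0
      return 0;
    }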
@@ -208,6 +208,15 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
     DequantizeOutput(g, conv_op, conv_output, "Output", output_scale,
                      is_output_unsigned, "Scale_out");
+    // change threshold in bounded ReLu
+    if (conv_op->Op()->HasAttr("fuse_brelu") &&
+        boost::get<bool>(conv_op->Op()->GetAttr("fuse_brelu"))) {
+      float scale_out = boost::get<float>(conv_op->Op()->GetAttr("Scale_out"));
+      float threshold =
+          boost::get<float>(conv_op->Op()->GetAttr("fuse_brelu_threshold"));
+      conv_op->Op()->SetAttr("fuse_brelu_threshold", scale_out * threshold);
+    }
     ++quantize_conv_count;
   };
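The attribute rewrite above keeps the clip bound consistent with the quantized output: once DequantizeOutput attaches Scale_out, the convolution output tensor carries values multiplied by that scale, so the bounded-ReLU threshold applied to it must be multiplied by the same factor. A small numeric sketch (the scale and input values are made up for illustration):

    #include <algorithm>
    #include <cstdio>

    int main() {
      const float scale_out = 127.0f / 8.0f;  // assumed quantization scale
      const float threshold = 6.0f;           // bounded-ReLU bound (ReLU6)
      const float real_value = 7.5f;          // pre-activation conv output

      // Reference: clip in the float domain, then quantize.
      float ref = std::min(std::max(real_value, 0.0f), threshold) * scale_out;

      // Quantized execution: the tensor is already scaled by scale_out,
      // so the clip bound must be rescaled the same way (what the pass does).
      float quantized = real_value * scale_out;
      float clipped = std::min(std::max(quantized, 0.0f), threshold * scale_out);

      std::printf("reference=%.2f quantized-path=%.2f\n", ref, clipped);  // both 95.25
      return 0;
    }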
@@ -484,9 +484,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         handler.reset(
             new platform::ConvMKLDNNHandler(dev_ctx, mkldnn_engine, key));
         // create a conv primitive descriptor and save it for usage in backward
-        // TODO(lidanqing): We use relu post-op instead of brelu post-op cause
-        // mkldnn v0.18 does not support INT8 brelu post-op. Use code in /**/ when
-        // v0.20 is enabled
         auto propagation = is_test ? mkldnn::prop_kind::forward_scoring
                                    : mkldnn::prop_kind::forward_training;
@@ -496,15 +493,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                                        mkldnn::memory::format::x);
           conv_pd = handler->AcquireConvolutionPrimitiveDescriptor(
               src_md, weights_md, bias_md, dst_md, strides, paddings,
-              mkldnn_engine, fuse_relu || fuse_brelu /*fuse_relu*/,
-              fuse_residual_conn, false /*fuse_brelu*/, fuse_brelu_threshold,
-              propagation, output_shift_scale, sum_scale);
+              mkldnn_engine, fuse_relu, fuse_residual_conn, fuse_brelu,
+              fuse_brelu_threshold, propagation, output_shift_scale, sum_scale);
         } else {
           conv_pd = handler->AcquireConvolutionPrimitiveDescriptor(
               src_md, weights_md, boost::none, dst_md, strides, paddings,
-              mkldnn_engine, fuse_relu || fuse_brelu /*fuse_relu*/,
-              fuse_residual_conn, false /*fuse_brelu*/, fuse_brelu_threshold,
-              propagation, output_shift_scale, sum_scale);
+              mkldnn_engine, fuse_relu, fuse_residual_conn, fuse_brelu,
+              fuse_brelu_threshold, propagation, output_shift_scale, sum_scale);
         }
         // create mkldnn memory from input tensors (data/weights)
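With the workaround removed, fuse_brelu and fuse_brelu_threshold are forwarded to the handler unchanged. Inside the handler this presumably becomes an mkldnn eltwise bounded-ReLU post-op; the sketch below shows how such a post-op is typically attached with the mkldnn v0.20 C++ API. MakeConvAttr is a hypothetical helper for illustration, not the body of ConvMKLDNNHandler::AcquireConvolutionPrimitiveDescriptor (which is not part of this diff):

    #include <vector>
    #include <mkldnn.hpp>

    // Sketch only: build a primitive_attr carrying output scales and a
    // bounded-ReLU post-op, as an INT8 conv primitive descriptor would use it.
    mkldnn::primitive_attr MakeConvAttr(bool fuse_brelu,
                                        float fuse_brelu_threshold,
                                        const std::vector<float>& output_shift_scale) {
      mkldnn::primitive_attr attr;
      // Per-output-channel scales when more than one scale is given.
      int mask = output_shift_scale.size() > 1 ? (1 << 1) : 0;
      attr.set_output_scales(mask, output_shift_scale);

      mkldnn::post_ops ops;
      if (fuse_brelu) {
        // alpha is the upper clip bound; beta is unused for bounded ReLU.
        ops.append_eltwise(1.0f, mkldnn::algorithm::eltwise_bounded_relu,
                           fuse_brelu_threshold, 0.0f);
      }
      attr.set_post_ops(ops);
      return attr;
    }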