From 2e09ab369df448974b7af0a2e46ae7905bd11a99 Mon Sep 17 00:00:00 2001 From: Haihao Shen Date: Fri, 9 Nov 2018 10:49:44 +0800 Subject: [PATCH] Support force fp32 ouput for int8 conv --- paddle/fluid/operators/conv_mkldnn_op.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc index b0585ddb334..43bfb821ae5 100644 --- a/paddle/fluid/operators/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/conv_mkldnn_op.cc @@ -339,6 +339,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { std::vector paddings = ctx.Attr>("paddings"); std::vector dilations = ctx.Attr>("dilations"); bool fuse_relu = ctx.Attr("fuse_relu"); + bool force_fp32_output = ctx.Attr("force_fp32_output"); bool fuse_residual_conn = ctx.Attr("fuse_residual_connection"); int groups = ctx.Attr("groups"); @@ -519,6 +520,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { if(dst_dt != residual_dt) dst_dt = residual_dt; } + if(force_fp32_output) + dst_dt = fuse_relu? paddle::framework::ToMKLDNNDataType(std::type_index(typeid(float))); dst_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format))); mds[2] = src_md; mds[3] = weights_md; -- GitLab