diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc index b0585ddb3344d3b045edf9c4d4e1852f8cfbcf70..43bfb821ae5640cdf70d0e22c8f865205f2fe87c 100644 --- a/paddle/fluid/operators/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/conv_mkldnn_op.cc @@ -339,6 +339,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { std::vector paddings = ctx.Attr>("paddings"); std::vector dilations = ctx.Attr>("dilations"); bool fuse_relu = ctx.Attr("fuse_relu"); + bool force_fp32_output = ctx.Attr("force_fp32_output"); bool fuse_residual_conn = ctx.Attr("fuse_residual_connection"); int groups = ctx.Attr("groups"); @@ -519,6 +520,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { if(dst_dt != residual_dt) dst_dt = residual_dt; } + if(force_fp32_output) + dst_dt = fuse_relu? paddle::framework::ToMKLDNNDataType(std::type_index(typeid(float))); dst_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format))); mds[2] = src_md; mds[3] = weights_md;