From 786a72cd2cb33aaad06da3693cbac4325235e304 Mon Sep 17 00:00:00 2001 From: Haihao Shen Date: Fri, 9 Nov 2018 20:40:01 +0800 Subject: [PATCH] Fix force_fp32_output to correct the dst_memory_pd --- paddle/fluid/operators/conv_mkldnn_op.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc index 385e1a798fb..8bcd9930a83 100644 --- a/paddle/fluid/operators/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/conv_mkldnn_op.cc @@ -413,6 +413,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { scale_weights_data[i] =*(scale_weights->data() + i); } scale_out_data = {*(scale_out->data())}; + if(force_fp32_output) + scale_out_data[0] = 1.0; output_shift_scale.resize(count); #pragma omp parallel for if (count > 1) for(int i=0; i { output->ShareDataWith(*residual_param); if(is_INT8){ if(residual_dt == mkldnn::memory::data_type::u8){ - uint8_t* output_data = output->mutable_data(ctx.GetPlace()); dst_memory_p = handler.AcquireDstMemoryFromPrimitive(to_void_cast(output_data)); @@ -712,7 +713,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { } } } else { - if(is_INT8){ + if(is_INT8 && !force_fp32_output){ if(fuse_relu){ uint8_t* output_data = output->mutable_data(ctx.GetPlace(), handler.GetDstMemorySize()); dst_memory_p = -- GitLab