From 7afea1b6b0a14c293b2222c4e5d2a2a9490eb5f7 Mon Sep 17 00:00:00 2001 From: Haihao Shen Date: Fri, 2 Nov 2018 13:31:33 +0800 Subject: [PATCH] Fix the potential issue on conversion from s8 to u8 --- paddle/fluid/operators/conv_mkldnn_op.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc index ca7275d044e..0393daf4ac3 100644 --- a/paddle/fluid/operators/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/conv_mkldnn_op.cc @@ -479,6 +479,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { } std::shared_ptr dst_memory_p; + bool need_s8_to_u8 = false; if(is_INT8){ if (fuse_residual_conn) { auto residual_param = ctx.Input("ResidualData"); @@ -496,6 +497,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { int8_t* output_data = output->mutable_data(ctx.GetPlace()); dst_memory_p = handler.AcquireDstMemoryFromPrimitive(to_void_cast(output_data)); + if(fuse_relu) + need_s8_to_u8 = true; } } else { if(fuse_relu){ @@ -565,7 +568,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { pipeline.push_back(*conv_p); stream(stream::kind::eager).submit(pipeline).wait(); - if(is_INT8 && fuse_residual_conn){ + if(need_s8_to_u8){ output->mutable_data(ctx.GetPlace()); } -- GitLab