From 93c4ee01074691a6b3786791ababd6f570ebec7f Mon Sep 17 00:00:00 2001 From: xiaolil1 Date: Wed, 7 Nov 2018 14:40:35 +0800 Subject: [PATCH] integrate residual different format fix to INT8 --- paddle/fluid/operators/conv_mkldnn_op.cc | 145 ++++++++++++----------- 1 file changed, 78 insertions(+), 67 deletions(-) diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc index 29008567546..f1ecfe41b96 100644 --- a/paddle/fluid/operators/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/conv_mkldnn_op.cc @@ -535,27 +535,80 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { std::shared_ptr dst_memory_p; bool need_s8_to_u8 = false; - if(is_INT8){ - if (fuse_residual_conn) { - auto residual_param = ctx.Input("ResidualData"); - PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(), - "Output and elementwise parameter need to have the " - "same dimension sizes"); - output->ShareDataWith(*residual_param); - auto residual_dt = paddle::framework::ToMKLDNNDataType(residual_param->type()); - if(residual_dt == mkldnn::memory::data_type::u8){ - - uint8_t* output_data = output->mutable_data(ctx.GetPlace()); - dst_memory_p = - handler.AcquireDstMemoryFromPrimitive(to_void_cast(output_data)); - } else{ - int8_t* output_data = output->mutable_data(ctx.GetPlace()); - dst_memory_p = - handler.AcquireDstMemoryFromPrimitive(to_void_cast(output_data)); - if(fuse_relu) - need_s8_to_u8 = true; - } + if(fuse_residual_conn) { + auto residual_param = ctx.Input("ResidualData"); + PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(), + "Output and elementwise parameter need to have the " + "same dimension sizes"); + auto residual_dt = paddle::framework::ToMKLDNNDataType(residual_param->type()); + if(residual_param->format() != handler.GetDstFormat()) { + auto residual_data_tz = + paddle::framework::vectorize2int(residual_param->dims()); + auto residual_data_type = + paddle::framework::ToMKLDNNDataType(residual_param->type()); + auto user_residual_md = platform::MKLDNNMemDesc( + residual_data_tz, residual_data_type, residual_param->format()); + if(is_INT8){ + if(residual_dt == mkldnn::memory::data_type::u8){ + auto residual_param_data = residual_param->data(); + auto user_residual_memory_p = handler.AcquireResidualDataMemory( + user_residual_md, to_void_cast(residual_param_data)); + PADDLE_ENFORCE( + residual_param_data != nullptr, + "Provide data if you want MKLDNN conv+elementwise_add fusion"); + uint8_t* output_data = output->mutable_data(ctx.GetPlace()); + dst_memory_p = + handler.AcquireDstMemoryFromResidualDataMemory( + user_residual_memory_p, to_void_cast(output_data), pipeline); + } else{ + auto residual_param_data = residual_param->data(); + auto user_residual_memory_p = handler.AcquireResidualDataMemory( + user_residual_md, to_void_cast(residual_param_data)); + PADDLE_ENFORCE( + residual_param_data != nullptr, + "Provide data if you want MKLDNN conv+elementwise_add fusion"); + int8_t* output_data = output->mutable_data(ctx.GetPlace()); + dst_memory_p = + handler.AcquireDstMemoryFromResidualDataMemory( + user_residual_memory_p, to_void_cast(output_data), pipeline); + if(fuse_relu) + need_s8_to_u8 = true; + } + } else{ + auto residual_param_data = residual_param->data(); + auto user_residual_memory_p = handler.AcquireResidualDataMemory( + user_residual_md, to_void_cast(residual_param_data)); + PADDLE_ENFORCE( + residual_param_data != nullptr, + "Provide data if you want MKLDNN conv+elementwise_add fusion"); + auto output_data = + output->mutable_data(ctx.GetPlace(), handler.GetDstMemorySize()); + dst_memory_p = handler.AcquireDstMemoryFromResidualDataMemory( + user_residual_memory_p, to_void_cast(output_data), pipeline); + } } else { + output->ShareDataWith(*residual_param); + if(is_INT8){ + if(residual_dt == mkldnn::memory::data_type::u8){ + + uint8_t* output_data = output->mutable_data(ctx.GetPlace()); + dst_memory_p = + handler.AcquireDstMemoryFromPrimitive(to_void_cast(output_data)); + } else{ + int8_t* output_data = output->mutable_data(ctx.GetPlace()); + dst_memory_p = + handler.AcquireDstMemoryFromPrimitive(to_void_cast(output_data)); + if(fuse_relu) + need_s8_to_u8 = true; + } + } else{ + auto output_data = output->mutable_data(ctx.GetPlace()); + dst_memory_p = + handler.AcquireDstMemoryFromPrimitive(to_void_cast(output_data)); + } + } + } else { + if(is_INT8){ if(fuse_relu){ uint8_t* output_data = output->mutable_data(ctx.GetPlace(), handler.GetDstMemorySize()); dst_memory_p = @@ -565,54 +618,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { dst_memory_p = handler.AcquireDstMemoryFromPrimitive(to_void_cast(output_data)); } + } else{ + auto output_data = + output->mutable_data(ctx.GetPlace(), handler.GetDstMemorySize()); + dst_memory_p = + handler.AcquireDstMemoryFromPrimitive(to_void_cast(output_data)); } - } else{ - // create reorder primitive if the input format is not the preferred one - // auto src_memory_p = - // handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline); - // auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive( - // user_weights_memory_p, pipeline, is_test); - // std::shared_ptr dst_memory_p; - - if (fuse_residual_conn) { - auto residual_param = ctx.Input("ResidualData"); - auto residual_param_data = residual_param->data(); - - PADDLE_ENFORCE( - residual_param_data != nullptr, - "Provide data if you want MKLDNN conv+elementwise_add fusion"); - PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(), - "Output and elementwise parameter need to have the " - "same dimension sizes"); - - if (residual_param->format() != handler.GetDstFormat()) { - auto output_data = - output->mutable_data(ctx.GetPlace(), handler.GetDstMemorySize()); - auto residual_data_tz = - paddle::framework::vectorize2int(residual_param->dims()); - auto residual_data_type = - paddle::framework::ToMKLDNNDataType(residual_param->type()); - auto user_residual_md = platform::MKLDNNMemDesc( - residual_data_tz, residual_data_type, residual_param->format()); - auto user_residual_memory_p = handler.AcquireResidualDataMemory( - user_residual_md, to_void_cast(residual_param_data)); - dst_memory_p = handler.AcquireDstMemoryFromResidualDataMemory( - user_residual_memory_p, to_void_cast(output_data), pipeline); - } else { - output->ShareDataWith(*residual_param); - auto output_data = output->mutable_data(ctx.GetPlace()); - dst_memory_p = - handler.AcquireDstMemoryFromPrimitive(to_void_cast(output_data)); - } - } else { - - auto output_data = - output->mutable_data(ctx.GetPlace(), handler.GetDstMemorySize()); - dst_memory_p = - handler.AcquireDstMemoryFromPrimitive(to_void_cast(output_data)); - } - // dst_memory_p = - // handler.AcquireDstMemoryFromPrimitive(to_void_cast(output_data)); } // create convolution op primitive -- GitLab