From c8f101e5da3497bfa12688d90d84cad52deee2f0 Mon Sep 17 00:00:00 2001 From: xiaolil1 <39753926+xiaolil1@users.noreply.github.com> Date: Mon, 7 Jan 2019 19:55:08 +0800 Subject: [PATCH] Conv int8 relu (#15130) * Enable basic MKL-DNN INT8 Conv OP test=develop * Modify test case test=develop * Clean unittest code test=develop * Fix test test=develop * Modify test test=develop * Enable MKL-DNN INT8 Conv with Relu Fusion OP test=develop * Modify basic INT8 Conv test=develop * fix type test=develop * Modify test test=develop --- paddle/fluid/operators/conv_mkldnn_op.cc | 69 ++++++++++++------ paddle/fluid/platform/mkldnn_reuse.h | 8 ++- .../unittests/test_conv2d_int8_mkldnn_op.py | 70 +++++++++++++++---- 3 files changed, 107 insertions(+), 40 deletions(-) diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc index 0f2bb8c65cf..03d9d466c32 100644 --- a/paddle/fluid/operators/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/conv_mkldnn_op.cc @@ -319,6 +319,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { std::vector dilations = ctx.Attr>("dilations"); int groups = ctx.Attr("groups"); + bool fuse_relu = ctx.Attr("fuse_relu"); + bool force_fp32_output = ctx.Attr("force_fp32_output"); bool is_conv3d = strides.size() == 3U; @@ -329,6 +331,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { dilations[2] == 1 : dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1, "dilation in convolution is not implemented yet"); + PADDLE_ENFORCE(is_conv3d != true, "int8 does not support conv3d currently"); const T* input_data = input->data(); @@ -340,15 +343,24 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { GetWeightsTz(weights_tz, g, is_conv3d); std::vector dst_tz = paddle::framework::vectorize2int(output->dims()); + mkldnn::memory::data_type src_dt = + paddle::framework::ToMKLDNNDataType(input->type()); + auto dst_dt = fuse_relu ? paddle::framework::ToMKLDNNDataType( + framework::DataTypeTrait::DataType) + : paddle::framework::ToMKLDNNDataType( + framework::DataTypeTrait::DataType); + + if (force_fp32_output) { + dst_dt = paddle::framework::ToMKLDNNDataType( + framework::DataTypeTrait::DataType); + } + // Get unique name for storing MKLDNN primitives std::string key; key.reserve(MaxKeyLength); - mkldnn::memory::data_type src_dt = - paddle::framework::ToMKLDNNDataType(input->type()); platform::ConvMKLDNNHandler::AppendKey( &key, src_tz, weights_tz, strides, paddings, dilations, groups, src_dt, - input->format(), ctx.op().Output("Output")); - + input->format(), dst_dt, ctx.op().Output("Output")); const std::string key_conv_pd = key + "@conv_pd"; std::shared_ptr conv_p = nullptr; @@ -413,13 +425,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { platform::MKLDNNMemDesc(src_tz, src_dt, chosen_memory_format); auto weights_md = platform::MKLDNNMemDesc( weights_tz, memory::data_type::s8, chosen_memory_format); - - auto dst_dt = force_fp32_output - ? paddle::framework::ToMKLDNNDataType( - framework::DataTypeTrait::DataType) - : paddle::framework::ToMKLDNNDataType( - framework::DataTypeTrait::DataType); - auto dst_md = platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format); // create a conv primitive descriptor and save it for usage in backward @@ -429,11 +434,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { memory::format::x); conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine, - output_shift_scale, is_test); + fuse_relu, output_shift_scale, is_test); } else { - conv_pd = - ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings, - mkldnn_engine, output_shift_scale, is_test); + conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, + paddings, mkldnn_engine, fuse_relu, + output_shift_scale, is_test); } // Save conv_pd/src_memory/weights_memory for backward pass dev_ctx.SetBlob(key_conv_pd, conv_pd); @@ -459,7 +464,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { mask_reorder); if (!force_fp32_output) { - dst_memory_p = platform::SetDstMemory(ctx, output, handler); + if (fuse_relu) { + dst_memory_p = platform::SetDstMemory(ctx, output, handler); + } else { + dst_memory_p = platform::SetDstMemory(ctx, output, handler); + } } else { dst_memory_p = platform::SetDstMemory(ctx, output, handler); } @@ -518,8 +527,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { mkldnn_engine, key)); } if (!force_fp32_output) { - dst_memory_p = - platform::SetDstMemoryHandler(ctx, output, handler); + if (fuse_relu) { + dst_memory_p = + platform::SetDstMemoryHandler(ctx, output, handler); + } else { + dst_memory_p = + platform::SetDstMemoryHandler(ctx, output, handler); + } } else { dst_memory_p = platform::SetDstMemoryHandler(ctx, output, handler); @@ -563,11 +577,18 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { } mkldnn::primitive_attr CreatePostOps( - const std::vector output_shift_scale) const { + bool fuse_relu, const std::vector output_shift_scale) const { mkldnn::primitive_attr conv_attr; mkldnn::post_ops post_operations; int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0; conv_attr.set_output_scales(mask, output_shift_scale); + if (fuse_relu) { + constexpr float scale = 1.0f; + constexpr float negative_slope = 0.0f; + constexpr float placeholder = 1.0f; // beta + post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu, + negative_slope, placeholder); + } conv_attr.set_post_ops(post_operations); return conv_attr; } @@ -600,7 +621,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights, const memory::desc& dst, const std::vector& strides, const std::vector& paddings, - const mkldnn::engine& engine, + const mkldnn::engine& engine, const bool fuse_relu, const std::vector output_shift_scale, bool is_test) const { memory::dims stride_dims = {strides[0], strides[1]}; @@ -613,7 +634,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { propagation, mkldnn::convolution_direct, src, weights, dst, stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero); - mkldnn::primitive_attr conv_attr = CreatePostOps(output_shift_scale); + mkldnn::primitive_attr conv_attr = + CreatePostOps(fuse_relu, output_shift_scale); auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc( conv_desc, conv_attr, engine); @@ -652,7 +674,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { const memory::desc& bias, const memory::desc& dst, const std::vector& strides, const std::vector& paddings, - const mkldnn::engine& engine, + const mkldnn::engine& engine, const bool fuse_relu, const std::vector output_shift_scale, bool is_test) const { memory::dims stride_dims = {strides[0], strides[1]}; @@ -665,7 +687,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { propagation, mkldnn::convolution_direct, src, weights, bias, dst, stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero); - mkldnn::primitive_attr conv_attr = CreatePostOps(output_shift_scale); + mkldnn::primitive_attr conv_attr = + CreatePostOps(fuse_relu, output_shift_scale); auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc( conv_desc, conv_attr, engine); diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 98d1242a169..b3d20736a8e 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -214,16 +214,18 @@ class MKLDNNHandler { std::string* key, const mkldnn::memory::dims& input_dims, const mkldnn::memory::dims& weights_dims, const std::vector& strides, const std::vector& paddings, const std::vector& dilations, - const int& groups, const mkldnn::memory::data_type& type, - const mkldnn::memory::format& format, const std::string& suffix) { + const int& groups, const mkldnn::memory::data_type& srcdt, + const mkldnn::memory::format& format, + const mkldnn::memory::data_type& dstdt, const std::string& suffix) { AppendKeyDims(key, input_dims); AppendKeyDims(key, weights_dims); AppendKeyVec(key, strides); AppendKeyVec(key, paddings); AppendKeyVec(key, dilations); AppendKey(key, std::to_string(groups)); - AppendKey(key, std::to_string(type)); + AppendKey(key, std::to_string(srcdt)); AppendKey(key, std::to_string(format)); + AppendKey(key, std::to_string(dstdt)); AppendKey(key, suffix); } diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_int8_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_int8_mkldnn_op.py index ca35adc1a36..def188bfa63 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_int8_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_int8_mkldnn_op.py @@ -47,7 +47,8 @@ class TestConv2dInt8Op(TestConv2dOp): self.init_group() self.init_dilation() self.init_test_case() - self.init_dtype() + self.init_fuse_relu() + self.init_data_type() conv2d_param = { 'stride': self.stride, @@ -78,7 +79,11 @@ class TestConv2dInt8Op(TestConv2dOp): np.round((input_shift) * self.scale_in).astype(np.int32), filter_int, self.groups, conv2d_param).astype(np.float32) * scale_output_shift - output = np.round(output1 - output2).astype(self.dsttype) + if self.fuse_relu: + output = np.maximum(np.round(output1 - output2), + 0).astype(self.dsttype) + else: + output = np.round(output1 - output2).astype(self.dsttype) else: filter_int = np.round(filter * self.scale_weights[0]).astype(np.int32) @@ -87,7 +92,15 @@ class TestConv2dInt8Op(TestConv2dOp): output1 = conv2d_forward_refer( input.astype(np.int32), filter_int, self.groups, conv2d_param).astype(np.float32) - output = np.round(output1 * scale_output_shift).astype(self.dsttype) + if self.fuse_relu: + output = np.maximum( + np.round(output1 * (self.scale_out / ( + self.scale_in * self.scale_weights[0]))), + 0).astype(self.dsttype) + else: + output = np.round(output1 * (self.scale_out / ( + self.scale_in * + self.scale_weights[0]))).astype(self.dsttype) self.inputs = { 'Input': @@ -106,6 +119,7 @@ class TestConv2dInt8Op(TestConv2dOp): 'Scale_in': self.scale_in, 'Scale_out': self.scale_out, 'Scale_weights': self.scale_weights, + 'fuse_relu': self.fuse_relu } self.outputs = {'Output': output} @@ -129,12 +143,15 @@ class TestConv2dInt8Op(TestConv2dOp): self.scale_out = 0.5 self.scale_weights = [10.0] - def init_dtype(self): + def init_data_type(self): self.srctype = np.uint8 self.dsttype = np.int8 + def init_fuse_relu(self): + self.fuse_relu = True -#--------------------test conv2d u8 in and s8 out-------------------- + +#--------------------test conv2d u8 in and u8 out-------------------- class TestConv2d(TestConv2dInt8Op): @@ -203,18 +220,43 @@ class TestWithInput1x1Filter1x1(TestConv2dInt8Op): self.groups = 3 -#--------------------test conv2d s8 in and s8 out-------------------- +def init_data_type_with_fusion(self, input_dt, fuse_relu): + self.srctype = input_dt + self.dsttype = np.uint8 if fuse_relu else np.int8 + + def init_fuse_relu(self): + self.fuse_relu = fuse_relu def create_test_int8_class(parent): - class TestInt8Case(parent): - def init_dtype(self): - self.srctype = np.int8 - self.dsttype = np.int8 - - cls_name = "{0}_{1}".format(parent.__name__, "s8s8") - TestInt8Case.__name__ = cls_name - globals()[cls_name] = TestInt8Case + + #--------------------test conv2d s8 in and u8 out-------------------- + + class TestS8U8Case(parent): + def init_data_type(self): + init_data_type_with_fusion(self, np.int8, True) + + #--------------------test conv2d s8 in and s8 out-------------------- + + class TestS8S8Case(parent): + def init_data_type(self): + init_data_type_with_fusion(self, np.int8, False) + + #--------------------test conv2d u8 in and s8 out-------------------- + + class TestU8S8Case(parent): + def init_data_type(self): + init_data_type_with_fusion(self, np.uint8, False) + + cls_name_s8u8 = "{0}_relu_{1}".format(parent.__name__, "1") + cls_name_s8s8 = "{0}_relu_{1}".format(parent.__name__, "0") + cls_name_u8s8 = "{0}_relu_{1}".format(parent.__name__, "0") + TestS8U8Case.__name__ = cls_name_s8u8 + TestS8S8Case.__name__ = cls_name_s8s8 + TestU8S8Case.__name__ = cls_name_u8s8 + globals()[cls_name_s8u8] = TestS8U8Case + globals()[cls_name_s8s8] = TestS8S8Case + globals()[cls_name_u8s8] = TestU8S8Case create_test_int8_class(TestConv2dInt8Op) -- GitLab