From 2deac4e447095718a88bb74d1831c3bc018f7e75 Mon Sep 17 00:00:00 2001 From: guomingz Date: Wed, 24 Apr 2019 15:41:15 +0800 Subject: [PATCH] Fix the bug of test_conv2d_int8_mkldnn case which raised by improper parameter passing (#17058) * resolve #17057 Fixed the bug that fuse_relu/fuse_residual option couldn't be passed to class TestConv2dInt8Op. test=develop * Fix the bug of test_conv2d_int8_mkldnn case which raised by improper parameter passing. test=develop --- .../mkldnn/test_conv2d_int8_mkldnn_op.py | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py index c7b8a096bf1..b9ef447b56f 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py @@ -102,25 +102,26 @@ class TestConv2dInt8Op(TestConv2dOp): output1 = conv2d_forward_refer( input.astype(np.int32), filter_int, self.groups, conv2d_param).astype(np.float32) + output1_tmp = np.round(output1 * ( + self.scale_out / (self.scale_in * self.scale_weights[0]))) + if self.fuse_residual: input_residual = np.random.randint( 0, 10, self.input_residual_size).astype(self.srctype) - output_tmp = np.round(output1 * (self.scale_out / ( + output_tmp_res = np.round(output1 * (self.scale_out / ( self.scale_in * self.scale_weights[0])) + format_reorder( input_residual, self.input_residual_size).astype( np.int32) * (self.scale_out / self.scale_in_eltwise )) - output_tmp2 = np.round(output1 * ( - self.scale_out / (self.scale_in * self.scale_weights[0]))) if self.fuse_relu: - output = np.maximum(output_tmp, 0).astype(self.dsttype) + output = np.maximum(output_tmp_res, 0).astype(self.dsttype) else: - output = output_tmp.astype(self.dsttype) + output = output_tmp_res.astype(self.dsttype) else: if self.fuse_relu: - output = np.maximum(output_tmp2, 0).astype(self.dsttype) + output = np.maximum(output1_tmp, 0).astype(self.dsttype) else: - output = output_tmp2.astype(self.dsttype) + output = output1_tmp.astype(self.dsttype) self.inputs = { 'Input': @@ -265,11 +266,9 @@ def init_data_type_with_fusion(self, input_dt, fuse_relu, fuse_residual): self.srctype = input_dt self.dsttype = np.uint8 if fuse_relu else np.int8 - def init_fuse_relu(self): - self.fuse_relu = fuse_relu + self.fuse_relu = fuse_relu - def init_fuse_residual(self): - self.fuse_residual = fuse_residual + self.fuse_residual = fuse_residual def create_test_int8_class(parent): -- GitLab