diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py index c7b8a096bf1a7e2f5b63b136c7036edad863c888..b9ef447b56f1d05c574d3e80ed830ec0dd6638bf 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py @@ -102,25 +102,26 @@ class TestConv2dInt8Op(TestConv2dOp): output1 = conv2d_forward_refer( input.astype(np.int32), filter_int, self.groups, conv2d_param).astype(np.float32) + output1_tmp = np.round(output1 * ( + self.scale_out / (self.scale_in * self.scale_weights[0]))) + if self.fuse_residual: input_residual = np.random.randint( 0, 10, self.input_residual_size).astype(self.srctype) - output_tmp = np.round(output1 * (self.scale_out / ( + output_tmp_res = np.round(output1 * (self.scale_out / ( self.scale_in * self.scale_weights[0])) + format_reorder( input_residual, self.input_residual_size).astype( np.int32) * (self.scale_out / self.scale_in_eltwise )) - output_tmp2 = np.round(output1 * ( - self.scale_out / (self.scale_in * self.scale_weights[0]))) if self.fuse_relu: - output = np.maximum(output_tmp, 0).astype(self.dsttype) + output = np.maximum(output_tmp_res, 0).astype(self.dsttype) else: - output = output_tmp.astype(self.dsttype) + output = output_tmp_res.astype(self.dsttype) else: if self.fuse_relu: - output = np.maximum(output_tmp2, 0).astype(self.dsttype) + output = np.maximum(output1_tmp, 0).astype(self.dsttype) else: - output = output_tmp2.astype(self.dsttype) + output = output1_tmp.astype(self.dsttype) self.inputs = { 'Input': @@ -265,11 +266,9 @@ def init_data_type_with_fusion(self, input_dt, fuse_relu, fuse_residual): self.srctype = input_dt self.dsttype = np.uint8 if fuse_relu else np.int8 - def init_fuse_relu(self): - self.fuse_relu = fuse_relu + self.fuse_relu = fuse_relu - def init_fuse_residual(self): - self.fuse_residual = fuse_residual + self.fuse_residual = fuse_residual def create_test_int8_class(parent):