提交 2deac4e4 编写于 作者: G guomingz 提交者: Tao Luo

Fix the bug of test_conv2d_int8_mkldnn case which raised by improper parameter passing (#17058)

* resolve #17057

Fixed the bug that fuse_relu/fuse_residual option couldn't be passed to class TestConv2dInt8Op.

test=develop

* Fix the bug of test_conv2d_int8_mkldnn case which raised by improper parameter passing.

test=develop
上级 d9cd9898
...@@ -102,25 +102,26 @@ class TestConv2dInt8Op(TestConv2dOp): ...@@ -102,25 +102,26 @@ class TestConv2dInt8Op(TestConv2dOp):
output1 = conv2d_forward_refer( output1 = conv2d_forward_refer(
input.astype(np.int32), filter_int, self.groups, input.astype(np.int32), filter_int, self.groups,
conv2d_param).astype(np.float32) conv2d_param).astype(np.float32)
output1_tmp = np.round(output1 * (
self.scale_out / (self.scale_in * self.scale_weights[0])))
if self.fuse_residual: if self.fuse_residual:
input_residual = np.random.randint( input_residual = np.random.randint(
0, 10, self.input_residual_size).astype(self.srctype) 0, 10, self.input_residual_size).astype(self.srctype)
output_tmp = np.round(output1 * (self.scale_out / ( output_tmp_res = np.round(output1 * (self.scale_out / (
self.scale_in * self.scale_weights[0])) + format_reorder( self.scale_in * self.scale_weights[0])) + format_reorder(
input_residual, self.input_residual_size).astype( input_residual, self.input_residual_size).astype(
np.int32) * (self.scale_out / self.scale_in_eltwise np.int32) * (self.scale_out / self.scale_in_eltwise
)) ))
output_tmp2 = np.round(output1 * (
self.scale_out / (self.scale_in * self.scale_weights[0])))
if self.fuse_relu: if self.fuse_relu:
output = np.maximum(output_tmp, 0).astype(self.dsttype) output = np.maximum(output_tmp_res, 0).astype(self.dsttype)
else: else:
output = output_tmp.astype(self.dsttype) output = output_tmp_res.astype(self.dsttype)
else: else:
if self.fuse_relu: if self.fuse_relu:
output = np.maximum(output_tmp2, 0).astype(self.dsttype) output = np.maximum(output1_tmp, 0).astype(self.dsttype)
else: else:
output = output_tmp2.astype(self.dsttype) output = output1_tmp.astype(self.dsttype)
self.inputs = { self.inputs = {
'Input': 'Input':
...@@ -265,11 +266,9 @@ def init_data_type_with_fusion(self, input_dt, fuse_relu, fuse_residual): ...@@ -265,11 +266,9 @@ def init_data_type_with_fusion(self, input_dt, fuse_relu, fuse_residual):
self.srctype = input_dt self.srctype = input_dt
self.dsttype = np.uint8 if fuse_relu else np.int8 self.dsttype = np.uint8 if fuse_relu else np.int8
def init_fuse_relu(self): self.fuse_relu = fuse_relu
self.fuse_relu = fuse_relu
def init_fuse_residual(self): self.fuse_residual = fuse_residual
self.fuse_residual = fuse_residual
def create_test_int8_class(parent): def create_test_int8_class(parent):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册