提交 2deac4e4 编写于 作者: G guomingz 提交者: Tao Luo

Fix the bug of test_conv2d_int8_mkldnn case which raised by improper parameter passing (#17058)

* resolve #17057

Fixed the bug that fuse_relu/fuse_residual option couldn't be passed to class TestConv2dInt8Op.

test=develop

* Fix the bug of test_conv2d_int8_mkldnn case which raised by improper parameter passing.

test=develop
上级 d9cd9898
......@@ -102,25 +102,26 @@ class TestConv2dInt8Op(TestConv2dOp):
output1 = conv2d_forward_refer(
input.astype(np.int32), filter_int, self.groups,
conv2d_param).astype(np.float32)
output1_tmp = np.round(output1 * (
self.scale_out / (self.scale_in * self.scale_weights[0])))
if self.fuse_residual:
input_residual = np.random.randint(
0, 10, self.input_residual_size).astype(self.srctype)
output_tmp = np.round(output1 * (self.scale_out / (
output_tmp_res = np.round(output1 * (self.scale_out / (
self.scale_in * self.scale_weights[0])) + format_reorder(
input_residual, self.input_residual_size).astype(
np.int32) * (self.scale_out / self.scale_in_eltwise
))
output_tmp2 = np.round(output1 * (
self.scale_out / (self.scale_in * self.scale_weights[0])))
if self.fuse_relu:
output = np.maximum(output_tmp, 0).astype(self.dsttype)
output = np.maximum(output_tmp_res, 0).astype(self.dsttype)
else:
output = output_tmp.astype(self.dsttype)
output = output_tmp_res.astype(self.dsttype)
else:
if self.fuse_relu:
output = np.maximum(output_tmp2, 0).astype(self.dsttype)
output = np.maximum(output1_tmp, 0).astype(self.dsttype)
else:
output = output_tmp2.astype(self.dsttype)
output = output1_tmp.astype(self.dsttype)
self.inputs = {
'Input':
......@@ -265,10 +266,8 @@ def init_data_type_with_fusion(self, input_dt, fuse_relu, fuse_residual):
self.srctype = input_dt
self.dsttype = np.uint8 if fuse_relu else np.int8
def init_fuse_relu(self):
self.fuse_relu = fuse_relu
def init_fuse_residual(self):
self.fuse_residual = fuse_residual
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册