提交 cf7b3253 编写于 作者: F FDInSky 提交者: Tao Luo

fix conv2d_transpose op unittest (#21927)

上级 f4013ecb
...@@ -60,8 +60,10 @@ class TestConv2dTransposeMKLDNNOp(TestConv2dTransposeOp): ...@@ -60,8 +60,10 @@ class TestConv2dTransposeMKLDNNOp(TestConv2dTransposeOp):
f_c = self.input_size[1] f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3] self.filter_size = [f_c, 6, 3, 3]
self.groups = 1 self.groups = 1
self.dtype = np.float32
def setUp(self): def setUp(self):
TestConv2dTransposeOp.setUp(self) TestConv2dTransposeOp.setUp(self)
output = self.outputs['Output'] output = self.outputs['Output']
......
...@@ -78,7 +78,7 @@ def conv2dtranspose_forward_naive(input_, filter_, attrs): ...@@ -78,7 +78,7 @@ def conv2dtranspose_forward_naive(input_, filter_, attrs):
out_h = output_size[0] + pad_h_0 + pad_h_1 out_h = output_size[0] + pad_h_0 + pad_h_1
out_w = output_size[1] + pad_w_0 + pad_w_1 out_w = output_size[1] + pad_w_0 + pad_w_1
out = np.zeros((in_n, out_c, out_h, out_w)) out = np.zeros((in_n, out_c, out_h, out_w), dtype=input_.dtype)
for n in range(in_n): for n in range(in_n):
for i in range(in_h): for i in range(in_h):
...@@ -108,7 +108,7 @@ def conv2dtranspose_forward_naive(input_, filter_, attrs): ...@@ -108,7 +108,7 @@ def conv2dtranspose_forward_naive(input_, filter_, attrs):
class TestConv2dTransposeOp(OpTest): class TestConv2dTransposeOp(OpTest):
def setUp(self): def setUp(self):
# init as conv transpose # init as conv transpose
self.dtype = np.float32 self.dtype = np.float64
self.is_test = False self.is_test = False
self.use_cudnn = False self.use_cudnn = False
self.use_mkldnn = False self.use_mkldnn = False
...@@ -119,8 +119,8 @@ class TestConv2dTransposeOp(OpTest): ...@@ -119,8 +119,8 @@ class TestConv2dTransposeOp(OpTest):
self.init_op_type() self.init_op_type()
self.init_test_case() self.init_test_case()
input_ = np.random.random(self.input_size).astype("float32") input_ = np.random.random(self.input_size).astype(self.dtype)
filter_ = np.random.random(self.filter_size).astype("float32") filter_ = np.random.random(self.filter_size).astype(self.dtype)
self.inputs = {'Input': input_, 'Filter': filter_} self.inputs = {'Input': input_, 'Filter': filter_}
self.attrs = { self.attrs = {
...@@ -138,7 +138,7 @@ class TestConv2dTransposeOp(OpTest): ...@@ -138,7 +138,7 @@ class TestConv2dTransposeOp(OpTest):
self.attrs['output_size'] = self.output_size self.attrs['output_size'] = self.output_size
output = conv2dtranspose_forward_naive(input_, filter_, output = conv2dtranspose_forward_naive(input_, filter_,
self.attrs).astype('float32') self.attrs).astype(self.dtype)
self.outputs = {'Output': output} self.outputs = {'Output': output}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册