提交 cf7b3253 编写于 作者: F FDInSky 提交者: Tao Luo

fix conv2d_transpose op unittest (#21927)

上级 f4013ecb
......@@ -60,8 +60,10 @@ class TestConv2dTransposeMKLDNNOp(TestConv2dTransposeOp):
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
self.groups = 1
self.dtype = np.float32
def setUp(self):
TestConv2dTransposeOp.setUp(self)
output = self.outputs['Output']
......
......@@ -78,7 +78,7 @@ def conv2dtranspose_forward_naive(input_, filter_, attrs):
out_h = output_size[0] + pad_h_0 + pad_h_1
out_w = output_size[1] + pad_w_0 + pad_w_1
out = np.zeros((in_n, out_c, out_h, out_w))
out = np.zeros((in_n, out_c, out_h, out_w), dtype=input_.dtype)
for n in range(in_n):
for i in range(in_h):
......@@ -108,7 +108,7 @@ def conv2dtranspose_forward_naive(input_, filter_, attrs):
class TestConv2dTransposeOp(OpTest):
def setUp(self):
# init as conv transpose
self.dtype = np.float32
self.dtype = np.float64
self.is_test = False
self.use_cudnn = False
self.use_mkldnn = False
......@@ -119,8 +119,8 @@ class TestConv2dTransposeOp(OpTest):
self.init_op_type()
self.init_test_case()
input_ = np.random.random(self.input_size).astype("float32")
filter_ = np.random.random(self.filter_size).astype("float32")
input_ = np.random.random(self.input_size).astype(self.dtype)
filter_ = np.random.random(self.filter_size).astype(self.dtype)
self.inputs = {'Input': input_, 'Filter': filter_}
self.attrs = {
......@@ -138,7 +138,7 @@ class TestConv2dTransposeOp(OpTest):
self.attrs['output_size'] = self.output_size
output = conv2dtranspose_forward_naive(input_, filter_,
self.attrs).astype('float32')
self.attrs).astype(self.dtype)
self.outputs = {'Output': output}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册