提交 5fe68931 编写于 作者: C chengduoZH

fix code struce

上级 8ad67da9
......@@ -21,41 +21,37 @@ def conv2d_forward_naive(input, filter, group, conv_param):
for i in range(out_h):
for j in range(out_w):
for g in range(group):
input_pad_masked = input_pad[:, g * f_c:(
g + 1) * f_c, i * stride[0]:i * stride[0] + f_h, j * stride[
1]:j * stride[1] + f_w]
input_pad_masked = \
input_pad[:, g * f_c:(g + 1) * f_c,
i * stride[0]:i * stride[0] + f_h,
j * stride[1]:j * stride[1] + f_w]
f_sub = filter[g * sub_out_c:(g + 1) * sub_out_c, :, :, :]
for k in range(sub_out_c):
out[:, g * sub_out_c + k, i, j] = np.sum(input_pad_masked *
f_sub[k, :, :, :],
axis=(1, 2, 3))
out[:, g * sub_out_c + k, i, j] = \
np.sum(input_pad_masked * f_sub[k, :, :, :],
axis=(1, 2, 3))
return out
class TestConv2dOp(OpTest):
def setUp(self):
self.init_groups()
self.init_optype()
pad = [0, 0]
stride = [1, 1]
input_size = [2, 3, 5, 5] # NCHW
assert np.mod(input_size[1], self.groups) == 0
f_c = input_size[1] / self.groups
filter_size = [6, f_c, 3, 3]
conv2d_param = {'stride': stride, 'pad': pad}
input = np.random.random(input_size).astype("float32")
filter = np.random.random(filter_size).astype("float32")
self.init_op_type()
self.init_group()
self.init_test_case()
conv2d_param = {'stride': self.stride, 'pad': self.pad}
input = np.random.random(self.input_size).astype("float32")
filter = np.random.random(self.filter_size).astype("float32")
output = conv2d_forward_naive(input, filter, self.groups, conv2d_param)
self.inputs = {'Input': input, 'Filter': filter}
self.attrs = {
'strides': stride,
'paddings': pad,
'strides': self.stride,
'paddings': self.pad,
'groups': self.groups,
'dilations': [1, 1]
'dilations': self.dilations
}
self.outputs = {'Output': output}
......@@ -80,30 +76,47 @@ class TestConv2dOp(OpTest):
max_relative_error=0.05,
no_grad_set=set(['Input']))
def init_groups(self):
def init_test_case(self):
self.groups = 1
self.op_type = "conv2d"
self.pad = [0, 0]
self.stride = [1, 1]
self.dilations = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] / self.groups
self.filter_size = [6, f_c, 3, 3]
def init_group(self):
self.groups = 1
def init_optype(self):
def init_op_type(self):
self.op_type = "conv2d"
class TestWithGroup(TestConv2dOp):
def init_groups(self):
def init_group(self):
self.groups = 3
def init_op_type(self):
self.op_type = "conv2d"
class TestCudnn2d(TestConv2dOp):
def init_optype(self):
self.op_type = "conv_cudnn"
class TestCudnn(TestConv2dOp):
def init_group(self):
self.groups = 1
class TestCudnn2dWithGroup(TestConv2dOp):
def init_optype(self):
def init_op_type(self):
self.op_type = "conv_cudnn"
def init_groups(self):
class TestCudnnWithGroup(TestConv2dOp):
def init_group(self):
self.groups = 3
def init_op_type(self):
self.op_type = "conv_cudnn"
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册