提交 3a48282e 编写于 作者: W wanghaoshuang

Fix unitest

上级 500e29a4
...@@ -64,7 +64,7 @@ class Im2SequenceOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -64,7 +64,7 @@ class Im2SequenceOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out", "(LodTensor)The output data of im2sequence op,"); AddOutput("Out", "(LodTensor)The output data of im2sequence op,");
AddAttr<std::vector<int>>("kernels", AddAttr<std::vector<int>>("kernels",
"(vector<int>), the " "(vector<int>), the "
"kernels(kernel_height, kernel_width)") "kernels(kernel_height, kernel_width)");
AddAttr<std::vector<int>>("strides", AddAttr<std::vector<int>>("strides",
"(vector<int> default:{1, 1}), the " "(vector<int> default:{1, 1}), the "
"strides(h_stride, w_stride)") "strides(h_stride, w_stride)")
......
...@@ -50,11 +50,11 @@ class Im2SequenceKernel : public framework::OpKernel<T> { ...@@ -50,11 +50,11 @@ class Im2SequenceKernel : public framework::OpKernel<T> {
int img_height = in_dim[2]; int img_height = in_dim[2];
int img_width = in_dim[3]; int img_width = in_dim[3];
auto kernels = ctx->Attrs().Get<std::vector<int>>("kernels"); auto kernels = ctx.Attr<std::vector<int>>("kernels");
auto strides = ctx->Attrs().Get<std::vector<int>>("strides"); auto strides = ctx.Attr<std::vector<int>>("strides");
auto paddings = ctx->Attrs().Get<std::vector<int>>("paddings"); auto paddings = ctx.Attr<std::vector<int>>("paddings");
int output_height = int output_height = OutputSize(img_height, kernels[0], paddings[0],
OutputSize(img_height, kernels[0], paddings[0], paddings[2] strides[0]); paddings[2], strides[0]);
int output_width = int output_width =
OutputSize(img_width, kernels[1], paddings[1], paddings[3], strides[1]); OutputSize(img_width, kernels[1], paddings[1], paddings[3], strides[1]);
...@@ -106,9 +106,9 @@ class Im2SequenceGradKernel : public framework::OpKernel<T> { ...@@ -106,9 +106,9 @@ class Im2SequenceGradKernel : public framework::OpKernel<T> {
int img_height = in_dim[2]; int img_height = in_dim[2];
int img_width = in_dim[3]; int img_width = in_dim[3];
auto kernels = ctx->Attrs().Get<std::vector<int>>("kernels"); auto kernels = ctx.Attr<std::vector<int>>("kernels");
auto strides = ctx->Attrs().Get<std::vector<int>>("strides"); auto strides = ctx.Attr<std::vector<int>>("strides");
auto paddings = ctx->Attrs().Get<std::vector<int>>("paddings"); auto paddings = ctx.Attr<std::vector<int>>("paddings");
int output_height = OutputSize(img_height, kernels[0], paddings[0], int output_height = OutputSize(img_height, kernels[0], paddings[0],
paddings[2], strides[0]); paddings[2], strides[0]);
int output_width = int output_width =
......
...@@ -20,22 +20,19 @@ def get_output_shape(attrs, in_shape): ...@@ -20,22 +20,19 @@ def get_output_shape(attrs, in_shape):
img_height = in_shape[2] img_height = in_shape[2]
img_width = in_shape[3] img_width = in_shape[3]
padding_height = attrs['padding_height'] paddings = attrs['paddings']
padding_width = attrs['padding_width'] kernels = attrs['kernels']
block_height = attrs['block_height'] strides = attrs['strides']
block_width = attrs['block_width']
stride_height = attrs['stride_height']
stride_width = attrs['stride_width']
output_height = \ output_height = \
1 + \ 1 + \
(img_height + 2 * padding_height - block_height + stride_height - 1) / \ (img_height + paddings[0] + paddings[2] - kernels[0] + strides[0] - 1) / \
stride_height strides[0]
output_width = \ output_width = \
1 + \ 1 + \
(img_width + 2 * padding_width - block_width + stride_width - 1) / \ (img_width + paddings[1] + paddings[3] - kernels[1] + strides[1] - 1) / \
stride_width strides[1]
return output_height, output_width return output_height, output_width
...@@ -46,19 +43,11 @@ def im2col(attrs, im, col): ...@@ -46,19 +43,11 @@ def im2col(attrs, im, col):
col: col:
{outputHeight, outputWidth, inputChannels, filterHeight, filterWidth} {outputHeight, outputWidth, inputChannels, filterHeight, filterWidth}
""" """
input_channels = im.shape[0] input_channels, input_height, input_width = im.shape
input_height = im.shape[1] output_height, output_width, _, filter_height, filter_width = col.shape
input_width = im.shape[2]
output_height = col.shape[0] stride_height, stride_width = attrs['strides']
output_width = col.shape[1] padding_height, padding_width = attrs['paddings'][0:2]
filter_height = col.shape[3]
filter_width = col.shape[4]
stride_height = attrs['stride_height']
stride_width = attrs['stride_width']
padding_height = attrs['padding_height']
padding_width = attrs['padding_width']
for col_row_idx in range(0, output_height): for col_row_idx in range(0, output_height):
for col_col_idx in range(0, output_width): for col_col_idx in range(0, output_width):
...@@ -92,7 +81,7 @@ def Im2Sequence(inputs, attrs): ...@@ -92,7 +81,7 @@ def Im2Sequence(inputs, attrs):
batch_size = inputs.shape[0] batch_size = inputs.shape[0]
out = np.zeros([ out = np.zeros([
batch_size, output_height, output_width, img_channels, batch_size, output_height, output_width, img_channels,
attrs['block_height'], attrs['block_width'] attrs['kernels'][0], attrs['kernels'][1]
]).astype("float32") ]).astype("float32")
for i in range(len(inputs)): for i in range(len(inputs)):
...@@ -100,7 +89,7 @@ def Im2Sequence(inputs, attrs): ...@@ -100,7 +89,7 @@ def Im2Sequence(inputs, attrs):
out = out.reshape([ out = out.reshape([
batch_size * output_height * output_width, batch_size * output_height * output_width,
img_channels * attrs['block_height'] * attrs['block_width'] img_channels * attrs['kernels'][0] * attrs['kernels'][1]
]) ])
return out return out
...@@ -112,12 +101,9 @@ class TestBlockExpandOp(OpTest): ...@@ -112,12 +101,9 @@ class TestBlockExpandOp(OpTest):
self.img_height = 4 self.img_height = 4
self.img_width = 4 self.img_width = 4
self.attrs = { self.attrs = {
'block_height': 2, 'kernels': [2, 2],
'block_width': 2, 'strides': [1, 1],
'stride_height': 1, 'paddings': [1, 1, 1, 1]
'stride_width': 1,
'padding_height': 1,
'padding_width': 1,
} }
def setUp(self): def setUp(self):
...@@ -145,12 +131,9 @@ class TestBlockExpandOpCase2(TestBlockExpandOp): ...@@ -145,12 +131,9 @@ class TestBlockExpandOpCase2(TestBlockExpandOp):
self.img_height = 4 self.img_height = 4
self.img_width = 5 self.img_width = 5
self.attrs = { self.attrs = {
'block_height': 2, 'kernels': [2, 1],
'block_width': 1, 'strides': [2, 1],
'stride_height': 2, 'paddings': [2, 1, 2, 1]
'stride_width': 1,
'padding_height': 2,
'padding_width': 1,
} }
...@@ -161,12 +144,9 @@ class TestBlockExpandOpCase3(TestBlockExpandOp): ...@@ -161,12 +144,9 @@ class TestBlockExpandOpCase3(TestBlockExpandOp):
self.img_height = 4 self.img_height = 4
self.img_width = 5 self.img_width = 5
self.attrs = { self.attrs = {
'block_height': 2, 'kernels': [2, 1],
'block_width': 1, 'strides': [2, 1],
'stride_height': 2, 'paddings': [2, 0, 2, 0]
'stride_width': 1,
'padding_height': 2,
'padding_width': 0,
} }
...@@ -177,12 +157,9 @@ class TestBlockExpandOpCase4(TestBlockExpandOp): ...@@ -177,12 +157,9 @@ class TestBlockExpandOpCase4(TestBlockExpandOp):
self.img_height = 3 self.img_height = 3
self.img_width = 3 self.img_width = 3
self.attrs = { self.attrs = {
'block_height': 2, 'kernels': [2, 2],
'block_width': 2, 'strides': [1, 1],
'stride_height': 1, 'paddings': [0, 0, 0, 0]
'stride_width': 1,
'padding_height': 0,
'padding_width': 0,
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册