提交 37de7755 编写于 作者: Z zhaojiaying01

fix im2col in case of 2*2 input

上级 17c5c289
...@@ -78,7 +78,7 @@ class Im2ColFunctor<ColFormat::kCFO, CPU, T> { ...@@ -78,7 +78,7 @@ class Im2ColFunctor<ColFormat::kCFO, CPU, T> {
(((isize - 2 * padding[0] + filter_height) % stride[0] == 0) ? 1 : 0)); (((isize - 2 * padding[0] + filter_height) % stride[0] == 0) ? 1 : 0));
int fill = isize % 2; int fill = isize % 2;
if (stride[0] == 1 && filter_height == 3 && pad1 && pad2 && if (stride[0] == 1 && filter_height == 3 && pad1 && pad2 &&
dilation[0] == 1) { dilation[0] == 1 && im_height > 2) {
for (int c = 0; c < im_channels; ++c) { for (int c = 0; c < im_channels; ++c) {
int oosize = osize * osize; int oosize = osize * osize;
int nk4 = osize / 4; int nk4 = osize / 4;
...@@ -250,7 +250,7 @@ class Im2ColFunctor<ColFormat::kCFO, CPU, T> { ...@@ -250,7 +250,7 @@ class Im2ColFunctor<ColFormat::kCFO, CPU, T> {
im_data += isize * isize; im_data += isize * isize;
} }
} else if (stride[0] == 2 && filter_height == 3 && pad1 && } else if (stride[0] == 2 && filter_height == 3 && pad1 &&
dilation[0] == 1) { dilation[0] == 1 && im_height > 2) {
for (int c = 0; c < im_channels; ++c) { for (int c = 0; c < im_channels; ++c) {
int oosize = osize * osize; int oosize = osize * osize;
int nk4 = osize / 4; int nk4 = osize / 4;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册