提交 660df122 编写于 作者: T tensor-tang

enable padding!=0 and fill height padding with 0

上级 d8e00fac
...@@ -48,13 +48,12 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -48,13 +48,12 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
const T* im_data = im.data<T>(); const T* im_data = im.data<T>();
T* col_data = col->data<T>(); T* col_data = col->data<T>();
// TODO(TJ): change me to template // TODO(TJ): change me to template
// further optimaze: // further optimize: padding == 1 need special
// 1. padding != 1
// 2. could also support stride_h != 1
if (stride[0] == 1 && stride[1] == 1 && dilation[0] == 1 && if (stride[0] == 1 && stride[1] == 1 && dilation[0] == 1 &&
dilation[1] == 1 && padding[0] == 0 && padding[1] == 0) { dilation[1] == 1) {
int col_matrix_width = output_width * output_height; int col_matrix_width = output_width * output_height;
int im_size = im_height * im_width; int im_size = im_height * im_width;
if (padding[0] == 0 && padding[1] == 0) {
size_t copy_size = sizeof(T) * output_width; size_t copy_size = sizeof(T) * output_width;
for (int oh = 0; oh < output_height; ++oh) { for (int oh = 0; oh < output_height; ++oh) {
const T* im_data_start = im_data + oh * im_width; const T* im_data_start = im_data + oh * im_width;
...@@ -71,6 +70,41 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -71,6 +70,41 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
} }
} }
return; return;
} else {
int plh = padding[0];
// int plw = padding[1];
int prh =
(output_height - 1) * stride[0] + filter_height - im_height - plh;
// int prw = (output_width - 1) * stride[1] + filter_width - im_width -
// plw;
// fill height padding : 0 ~ plh-1, (oh-prh) ~ (oh-1)
// TODO(TJ): reuse sizes
assert(plh == prh); // because stride_h == 1
for (int ph = 0; ph < plh; ++ph) {
size_t sz = sizeof(T) * output_width * (plh - ph);
T* col_start_l = col_data + ph * filter_width * col_matrix_width;
T* col_start_r =
col_data +
(filter_width - ph - 1) * filter_width * col_matrix_width +
col_matrix_width - output_width * (plh - ph);
for (int ic = 0; ic < im_channels; ++ic) {
T* dst_data_l =
col_start_l +
ic * filter_width * filter_height * col_matrix_width;
T* dst_data_r =
col_start_r +
ic * filter_width * filter_height * col_matrix_width;
for (int kw = 0; kw < filter_width; ++kw) {
std::memset(dst_data_l, 0, sz);
std::memset(dst_data_r, 0, sz);
dst_data_l = dst_data_l + col_matrix_width;
dst_data_r = dst_data_r + col_matrix_width;
}
}
}
return;
}
} }
for (int c = 0; c < channels_col; ++c) { for (int c = 0; c < channels_col; ++c) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册